From 8d0dd86c828aff1dec419b703a373d72255d588a Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Wed, 19 Jul 2023 14:19:50 +0000 Subject: [PATCH 1/7] add GQA for llama2 --- examples/cpp/llama/CMakeLists.txt | 2 +- lmdeploy/serve/turbomind/deploy.py | 25 +++- .../decoder_masked_multihead_attention.h | 1 + ...er_masked_multihead_attention_template.cuh | 107 ++++++++--------- .../kernels/unfused_attention_kernels.cu | 108 +++++++++--------- .../kernels/unfused_attention_kernels.h | 73 +----------- .../llama/LlamaContextAttentionLayer.cc | 43 ++++--- .../models/llama/LlamaContextAttentionLayer.h | 8 +- .../models/llama/LlamaContextDecoder.cc | 6 +- .../models/llama/LlamaContextDecoder.h | 3 +- src/turbomind/models/llama/LlamaDecoder.cc | 6 +- src/turbomind/models/llama/LlamaDecoder.h | 7 +- .../models/llama/LlamaDecoderLayerWeight.cc | 11 +- .../models/llama/LlamaDecoderLayerWeight.h | 9 +- .../llama/LlamaDecoderSelfAttentionLayer.cc | 35 ++++-- .../llama/LlamaDecoderSelfAttentionLayer.h | 5 + src/turbomind/models/llama/LlamaV2.cc | 12 +- src/turbomind/models/llama/LlamaV2.h | 3 +- src/turbomind/models/llama/LlamaWeight.cc | 16 ++- src/turbomind/models/llama/LlamaWeight.h | 6 +- src/turbomind/models/llama/llama_kernels.cu | 43 +++++-- src/turbomind/models/llama/llama_kernels.h | 1 + .../triton_backend/llama/LlamaTritonModel.cc | 31 +++-- .../triton_backend/llama/LlamaTritonModel.h | 1 + 24 files changed, 296 insertions(+), 266 deletions(-) diff --git a/examples/cpp/llama/CMakeLists.txt b/examples/cpp/llama/CMakeLists.txt index 6f53a0aa2..a9e9bda93 100644 --- a/examples/cpp/llama/CMakeLists.txt +++ b/examples/cpp/llama/CMakeLists.txt @@ -3,6 +3,6 @@ add_executable(llama_triton_example llama_triton_example.cc) target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils - nvtx_utils word_list glog) + nvtx_utils word_list) install(TARGETS llama_triton_example DESTINATION ${CMAKE_INSTALL_PREFIX}/bin) diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py index 755a09d5d..cbb1d8eea 100644 --- a/lmdeploy/serve/turbomind/deploy.py +++ b/lmdeploy/serve/turbomind/deploy.py @@ -133,10 +133,12 @@ def save_bin(param: torch.Tensor, name): if key == 'w_qkv' and ext == 'bias': attn_bias = True copy = False - if key in ['w1', 'w3', 'w_qkv']: + if key in ['w1', 'w3']: split_dim = -1 if key == 'w1': inter_size = param_data.shape[-1] + elif key == 'w_qkv': + split_dim = 1 elif key in ['w2', 'wo']: if ext in ['scales', 'zeros', 'bias']: copy = True @@ -316,6 +318,16 @@ def permute(x: torch.Tensor): 1).transpose(1, 2).reshape(dim, 1) +def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + kv_head_num: int, dim: int): + + def reshape(x): + return x.view(x.size(0), kv_head_num, -1) if dim == 2 else x.view( + kv_head_num, -1) + + return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1) + + def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, triton_models_path: str, tp: int): """Deploy a model with huggingface transformers' format. 
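To make the new `merge_qkv` helper in `deploy.py` concrete, here is a minimal standalone sketch of the layout it produces. The sizes below (hidden=8, head_num=4, kv_head_num=2, head_dim=2) are illustrative only and do not come from any real checkpoint; the helper mirrors the one added in the diff, with `dim` selecting between 2-D weight tensors and 1-D bias-like tensors.

```python
import torch

def merge_qkv(q, k, v, kv_head_num, dim):
    # Mirrors the helper added to deploy.py: every KV head is grouped with the
    # query heads that share it, so Q/K/V form one contiguous block per group.
    def reshape(x):
        return x.view(x.size(0), kv_head_num, -1) if dim == 2 else x.view(kv_head_num, -1)
    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)

# Illustrative sizes only: 4 query heads sharing 2 KV heads, head_dim 2.
hidden, head_num, kv_head_num, head_dim = 8, 4, 2, 2
q = torch.randn(hidden, head_num * head_dim)     # [8, 8]
k = torch.randn(hidden, kv_head_num * head_dim)  # [8, 4]
v = torch.randn(hidden, kv_head_num * head_dim)  # [8, 4]

qkv = merge_qkv(q, k, v, kv_head_num, dim=2)
# [hidden, kv_head_num, (head_num // kv_head_num + 2) * head_dim]
print(qkv.shape)  # torch.Size([8, 2, 8])
```

Keeping each group of `head_num // kv_head_num` query heads contiguous with the K/V head they share is what allows `save_bin` to split the fused `w_qkv` tensor along dim 1 (`split_dim = 1`) later in this patch.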
@@ -349,6 +361,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, model_arg = json.load(f) num_layer = model_arg['num_hidden_layers'] norm_eps = model_arg['rms_norm_eps'] + if 'num_key_value_heads' in model_arg: + kv_head_num = model_arg['num_key_value_heads'] + else: + kv_head_num = model_arg['num_attention_heads'] except Exception as e: print(f'get "num_hidden_layers" and "rms_norm_eps" from ' f'{params_path} failed: {e}') @@ -416,11 +432,10 @@ def get_tensor_transposed(name: str): q = permute(q) k = permute(k) if suffix == _qweight: # weight, qweight - # insert a dimension for splitting heads later - qkv = torch.stack((q, k, v), dim=1) + qkv = merge_qkv(q, k, v, kv_head_num, dim=2) + print(suffix, qkv.shape) else: # scales, zeros, bias - qkv = torch.stack((q.squeeze(), k.squeeze(), v.squeeze()), - dim=0).squeeze(dim=-1) + qkv = merge_qkv(q, k, v, kv_head_num, dim=1) print(suffix, qkv.shape) for k, v in [('w_qkv', qkv), ('wo', o)]: model_params[f'layers.{i}.attention.{k}.{suffix}'] = v diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention.h b/src/turbomind/kernels/decoder_masked_multihead_attention.h index 2cf354f99..5cf502555 100644 --- a/src/turbomind/kernels/decoder_masked_multihead_attention.h +++ b/src/turbomind/kernels/decoder_masked_multihead_attention.h @@ -132,6 +132,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base { T** v_cache_per_sample = nullptr; size_t kv_cache_per_sample_offset = 0; bool k_cache_interleaved = true; + int num_kv_heads = 0; }; template diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh index 308f9af97..892c31a9b 100644 --- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh +++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh @@ -80,8 +80,7 @@ namespace mmha { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct Qk_vec_m_ { -}; +struct Qk_vec_m_ {}; template<> struct Qk_vec_m_ { @@ -181,8 +180,7 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct K_vec_m_ { -}; +struct K_vec_m_ {}; template<> struct K_vec_m_ { @@ -263,8 +261,7 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct V_vec_m_ { -}; +struct V_vec_m_ {}; template<> struct V_vec_m_ { @@ -344,8 +341,7 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> { #ifdef MMHA_USE_FP32_ACUM_FOR_FMA template -struct Qk_vec_acum_fp32_ { -}; +struct Qk_vec_acum_fp32_ {}; template<> struct Qk_vec_acum_fp32_ { @@ -427,8 +423,7 @@ struct Qk_vec_acum_fp32_ { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct K_vec_acum_fp32_ { -}; +struct K_vec_acum_fp32_ {}; template<> struct K_vec_acum_fp32_ { @@ -490,8 +485,7 @@ struct K_vec_acum_fp32_ { #ifdef MMHA_USE_FP32_ACUM_FOR_OUT template -struct V_vec_acum_fp32_ { -}; +struct V_vec_acum_fp32_ {}; template<> struct V_vec_acum_fp32_ { @@ -677,7 +671,7 @@ inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) " {%7, %7, %7, %7}; \n" : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x) 
"r"(a.y), "r"(b), "f"(zero)); + : "r"(a.x), "r"(a.y), "r"(b), "f"(zero)); return c; } @@ -1349,16 +1343,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params if (params.finished != nullptr && params.finished[bi] == true) { return; } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; + // The head. const int hi = blockIdx.x; // Combine the batch and the head indices. const int bhi = bi * params.num_heads + hi; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads + hi; + + const int head_n_rep = params.num_heads / params.num_kv_heads; + + const int kvhi = hi / head_n_rep; // heads in the same group collapse to the same kv head + + const bool group_leader = hi % head_n_rep == 0; // only group leader writes to kv cache + // The thread in the block. const int tidx = threadIdx.x; @@ -1369,8 +1365,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params float qk = 0.0F; - int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; - const size_t bi_seq_len_offset = bi * params.memory_max_len; const int tlength = params.length_per_sample[bi] + params.max_prefix_prompt_length; @@ -1380,10 +1374,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. const bool is_masked = tidx >= QK_VECS_PER_WARP; + const int q_base_offset = bi * params.stride + hi * Dh; + const int k_base_offset = bi * params.stride + kvhi * Dh; + // The offset in the Q and K buffer also accounts for the batch. - int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + const int q_offset = q_base_offset + tidx * QK_VEC_SIZE; + const int k_offset = k_base_offset + tidx * QK_VEC_SIZE; + // The offset in the bias buffer. - int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + const int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + const int k_bias_offset = kvhi * Dh + tidx * QK_VEC_SIZE; // past kv quant param const float k_scale = params.attention_k_scale; @@ -1393,31 +1393,30 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params Qk_vec_k q; zero(q); if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); + q = vec_conversion(*reinterpret_cast(¶ms.q[q_offset])); } Qk_vec_k k; zero(k); { k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? - vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : + vec_conversion(*reinterpret_cast(¶ms.k[k_offset])) : k; } // Trigger the loads from the Q and K bias buffers. Qk_vec_k q_bias; zero(q_bias); - q_bias = - (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - vec_conversion(*reinterpret_cast(¶ms.q_bias[qk_bias_offset])) : - q_bias; + q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + vec_conversion(*reinterpret_cast(¶ms.q_bias[q_bias_offset])) : + q_bias; Qk_vec_k k_bias; zero(k_bias); if (handle_kv) { k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? 
- vec_conversion(*reinterpret_cast(¶ms.k_bias[qk_bias_offset])) : + vec_conversion(*reinterpret_cast(¶ms.k_bias[k_bias_offset])) : k_bias; } @@ -1454,7 +1453,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // The position of the thread in that 16B chunk. int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - if (handle_kv) { + if (handle_kv && group_leader) { // Trigger the stores to global memory. if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { if (!params.k_cache_per_sample) { @@ -1476,12 +1475,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params else { int offset; if (params.k_cache_interleaved) { - offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci; } else { - offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh - + co * QK_ELTS_IN_16B + ci; + offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + + tlength_circ * Dh + co * QK_ELTS_IN_16B + ci; } if (not QUANT_POLICY) { @@ -1577,7 +1576,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params if (not QUANT_POLICY) { k_cache_batch = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset - + hi * params.memory_max_len * Dh + ki) : + + kvhi * params.memory_max_len * Dh + ki) : ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; // Base pointer for the beam's batch, before offsetting with indirection buffer // T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; @@ -1586,7 +1585,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // convert k_cache_per_sample to int8 if (params.k_cache_per_sample) { int8_t* ptr = reinterpret_cast(params.k_cache_per_sample[bi]); - k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki; + k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki; } else { int8_t* ptr = reinterpret_cast(params.k_cache); @@ -1765,7 +1764,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params if (not QUANT_POLICY) { v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset - + hi * params.memory_max_len * Dh + vi) : + + kvhi * params.memory_max_len * Dh + vi) : ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer // T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; @@ -1774,7 +1773,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params else if (QUANT_POLICY == 4) { if (params.v_cache_per_sample) { int8_t* ptr = reinterpret_cast(params.v_cache_per_sample[bi]); - v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi; + v_cache_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi; } else { int8_t* ptr = reinterpret_cast(params.v_cache); @@ -1787,22 +1786,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params // The number of values processed per iteration of the loop. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - // One group of threads computes the product(s) for the current timestep. 
- V_vec_k v_bias; - zero(v_bias); - // if( vo == params.timestep % V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - if (handle_kv) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = vec_conversion( - *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); - } - } - } - } - // From previous, before values, step // Also make sure the logits are in shared memory. __syncthreads(); @@ -1924,14 +1907,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params V_vec_k v; // Trigger the loads from the V buffer. - const auto v_offset = qkv_base_offset + vi; - v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); + const auto v_offset = k_base_offset + vi; + + v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); + // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + if (params.v_bias != nullptr) { + V_vec_k v_bias = *reinterpret_cast(¶ms.v_bias[kvhi * Dh + vi]); + v = add(v, v_bias); + } - // Compute the V values with bias. - if (handle_kv) { - v = add(v, v_bias); + // Store the V values to cache + if (handle_kv && group_leader) { // Store the values with bias back to global memory in the cache for V. //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu index 8ec094fa3..70c2e2e60 100644 --- a/src/turbomind/kernels/unfused_attention_kernels.cu +++ b/src/turbomind/kernels/unfused_attention_kernels.cu @@ -20,6 +20,7 @@ #include "src/turbomind/kernels/unfused_attention_kernels.h" #include "src/turbomind/utils/cuda_type_utils.cuh" #include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/logger.h" namespace turbomind { @@ -1336,6 +1337,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* const int batch_size, const int seq_len, const int head_num, + const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const bool neox_rotary_style) @@ -1396,7 +1398,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* const int prefix_prompt_length = PREFIX_PROMPT ? 
param.d_prefix_prompt_lengths[batch_idx] : 0; const int hidden_idx = head_idx * size_per_head + tidx * vec_size; - const int n = head_num * size_per_head; + // const int n = head_num * size_per_head; // the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len) // and Q [0..seq_len) @@ -1405,33 +1407,48 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* // NOTE: q has seq len excluding prefix prompt // src QKV: [batch, time, 3, head, hidden] - const int src_q_idx = token_idx * 3 * n + hidden_idx; - const int src_k_idx = token_idx * 3 * n + hidden_idx + n; - const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n; + // const int src_q_idx = token_idx * 3 * n + hidden_idx; + // const int src_k_idx = token_idx * 3 * n + hidden_idx + n; + // const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n; + + const int q_kv_head_num = head_num + 2 * kv_head_num; + + const int k_offset = head_num * size_per_head; + const int v_offset = k_offset + kv_head_num * size_per_head; + + // src QKV: [batch, time, q_kv_head_num, hidden] + const int src_q_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx; + const int src_k_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + k_offset; + const int src_v_idx = token_idx * q_kv_head_num * size_per_head + hidden_idx + v_offset; Vec_t q, k, v; Vec_t q_bias, k_bias, v_bias; + + // load Q and apply bias if (!is_masked) { q = *reinterpret_cast(&QKV[src_q_idx]); - k = *reinterpret_cast(&QKV[src_k_idx]); - v = *reinterpret_cast(&QKV[src_v_idx]); - if (qkv_bias) { q_bias = *reinterpret_cast(&qkv_bias[hidden_idx]); - k_bias = *reinterpret_cast(&qkv_bias[hidden_idx + n]); - v_bias = *reinterpret_cast(&qkv_bias[hidden_idx + 2 * n]); + q = mmha::add(q, q_bias); } } - if (qkv_bias) { - q = mmha::add(q, q_bias); - k = mmha::add(k, k_bias); - v = mmha::add(v, v_bias); + // load KV and apply bias + if (!is_masked && head_idx < kv_head_num) { + k = *reinterpret_cast(&QKV[src_k_idx]); + v = *reinterpret_cast(&QKV[src_v_idx]); + if (qkv_bias) { + k_bias = *reinterpret_cast(&qkv_bias[hidden_idx + k_offset]); + v_bias = *reinterpret_cast(&qkv_bias[hidden_idx + v_offset]); + k = mmha::add(k, k_bias); + v = mmha::add(v, v_bias); + } } const int t_offset = history_length ? 
history_length[batch_idx] : 0; if (!neox_rotary_style) { + // TODO: unused computation on k if GQA is used mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, dst_kv_seq_idx + t_offset); } else { @@ -1472,23 +1489,28 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); } } + if (!is_masked && !q_buf) { // also skip modifying QKV if q/k/v_buf are present *reinterpret_cast(&QKV[src_q_idx]) = q; - *reinterpret_cast(&QKV[src_k_idx]) = k; - *reinterpret_cast(&QKV[src_v_idx]) = v; + if (head_idx < kv_head_num) { + *reinterpret_cast(&QKV[src_k_idx]) = k; + *reinterpret_cast(&QKV[src_v_idx]) = v; + } } const int dest_q_idx = batch_idx * size_per_head * seq_len * head_num + head_idx * size_per_head * seq_len + seq_idx * size_per_head + tidx * vec_size; - const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * head_num + const int dest_kv_idx = batch_idx * size_per_head * total_seq_len * kv_head_num + head_idx * size_per_head * total_seq_len + dst_kv_seq_idx * size_per_head + tidx * vec_size; if (!is_masked) { - *reinterpret_cast(&q_buf[dest_q_idx]) = q; - *reinterpret_cast(&k_buf[dest_kv_idx]) = k; - *reinterpret_cast(&v_buf[dest_kv_idx]) = v; + *reinterpret_cast(&q_buf[dest_q_idx]) = q; + if (head_idx < kv_head_num) { + *reinterpret_cast(&k_buf[dest_kv_idx]) = k; + *reinterpret_cast(&v_buf[dest_kv_idx]) = v; + } } } @@ -1504,6 +1526,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* batch_size, \ seq_len, \ head_num, \ + kv_head_num, \ size_per_head, \ rotary_embedding_dim, \ neox_rotary_style); @@ -1521,6 +1544,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int seq_len, const int token_num, const int head_num, + const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const int neox_rotary_style, @@ -1528,42 +1552,19 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int int8_mode, cudaStream_t stream) { - // [bs, seq_len, 3, head, Dh] - if (rotary_embedding_dim == 0 && param.max_prefix_prompt_length == 0) { - const int m = token_num; - const int n = head_num * size_per_head; - dim3 block(384); - dim3 grid((int)(ceil(1.0 * m * n / 384))); - add_fusedQKV_bias_transpose_kernel<<>>(q_buf, - k_buf, - v_buf, - QKV, - qkv_bias, - padding_offset, - batch_size, - seq_len, - token_num, - head_num, - size_per_head, - scale, - int8_mode); + TM_LOG_ERROR("invokeAddFusedQKVBiasTranspose"); + FT_CHECK(rotary_embedding_dim); + FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec) + // To implement rotary embeddings, each thread processes two QKV elems: + dim3 block((size_per_head / Vec_t::size + 31) / 32 * 32); + dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num); + size_t smem_size = neox_rotary_style ? 2 * rotary_embedding_dim * sizeof(T) : 0; + // NOTE: add offset for rotary embedding + if (param.max_prefix_prompt_length == 0) { + FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false); } else { - FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec) - // To implement rotary embeddings, each thread processes two QKV elems: - dim3 block((size_per_head / Vec_t::size + 31) / 32 * 32); - dim3 grid(token_num + batch_size * param.max_prefix_prompt_length, head_num); - size_t smem_size = neox_rotary_style ? 
2 * rotary_embedding_dim * sizeof(T) : 0; - // NOTE: add offset for rotary embedding - // add_fusedQKV_bias_transpose_kernel<<>>( - // q_buf, k_buf, v_buf, param, QKV, qkv_bias, batch_size, seq_len, head_num, size_per_head, - // rotary_embedding_dim); - if (param.max_prefix_prompt_length == 0) { - FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, false); - } - else { - FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true); - } + FUSED_QKV_BIAS_TRANSPOSE_LAUNCH(T, true); } } @@ -1580,6 +1581,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int seq_len, \ const int token_num, \ const int head_num, \ + const int kv_head_num, \ const int size_per_head, \ const int rotary_embedding_dim, \ const int neox_rotary_style, \ diff --git a/src/turbomind/kernels/unfused_attention_kernels.h b/src/turbomind/kernels/unfused_attention_kernels.h index d804dace0..51b33a287 100644 --- a/src/turbomind/kernels/unfused_attention_kernels.h +++ b/src/turbomind/kernels/unfused_attention_kernels.h @@ -112,6 +112,7 @@ struct PrefixPromptBatchWeightsParam { // l * 2 * hidden_units_ / tensor_para_.world_size_ const size_t prefix_prompt_layer_offset_per_seq = 0; }; + template void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, @@ -125,83 +126,13 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int seq_len, const int token_num, const int head_num, + const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const int neox_rotary_style, const float* scale, const int int8_mode, cudaStream_t stream); -template -void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - PrefixPromptBatchWeightsParam param, - T* QKV, - const T* qkv_bias, - const int* padding_offset, - const int batch_size, - const int seq_len, - const int token_num, - const int head_num, - const int size_per_head, - const int rotary_embedding_dim, - const int neox_rotary_style, - const float* scale, - const int int8_mode, - cudaStream_t stream) -{ - invokeAddFusedQKVBiasTranspose(q_buf, - k_buf, - v_buf, - param, - QKV, - qkv_bias, - padding_offset, - nullptr, - batch_size, - seq_len, - token_num, - head_num, - size_per_head, - rotary_embedding_dim, - neox_rotary_style, - scale, - int8_mode, - stream); -} - -template -void invokeAddFusedQKVBiasTranspose(T* q_buf, - T* k_buf, - T* v_buf, - T* QKV, - const T* qkv_bias, - const int* padding_offset, - const int batch_size, - const int seq_len, - const int token_num, - const int head_num, - const int size_per_head, - cudaStream_t stream) -{ - invokeAddFusedQKVBiasTranspose(q_buf, - k_buf, - v_buf, - PrefixPromptBatchWeightsParam{}, - QKV, - qkv_bias, - padding_offset, - batch_size, - seq_len, - token_num, - head_num, - size_per_head, - 0, - false, - (float*)nullptr, - 0, - stream); -} template void invokeTranspose4d(T* dst, diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc index 1e872a59e..806053c55 100644 --- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc +++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc @@ -27,6 +27,7 @@ #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/utils/Tensor.h" #include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/logger.h" namespace turbomind { @@ -38,13 +39,17 @@ void LlamaContextAttentionLayer::allocateBuffer(size_t batch_size, { TM_LOG_DEBUG(__PRETTY_FUNCTION__); + const int local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; + // no padding - qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * 
num_token * 3 * local_hidden_units_, true); + qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, sizeof(T) * num_token * local_q_kv_head_num * size_per_head_, true); // padding is rebuilt for q/k/v_buf_2_ - q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * 3 * batch_size * max_q_len * local_hidden_units_, true); - k_buf_2_ = q_buf_2_ + batch_size * max_q_len * local_hidden_units_; - v_buf_2_ = k_buf_2_ + batch_size * max_q_len * local_hidden_units_; + // [qH + 2kvH, B, S, D] + q_buf_2_ = (T*)allocator_->reMalloc( + q_buf_2_, sizeof(T) * local_q_kv_head_num * batch_size * max_q_len * size_per_head_, true); + k_buf_2_ = q_buf_2_ + local_head_num_ * batch_size * max_q_len * size_per_head_; + v_buf_2_ = k_buf_2_ + local_kv_head_num_ * batch_size * max_q_len * size_per_head_; if (use_fmha_) { FlashAttentionOp flash_attention(batch_size, local_head_num_, max_k_len, max_q_len, size_per_head_); @@ -54,19 +59,19 @@ void LlamaContextAttentionLayer::allocateBuffer(size_t batch_size, } else { k_cache_buf_ = (T*)allocator_->reMalloc( - k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true); - v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_; + k_cache_buf_, 2 * sizeof(T) * batch_size * local_kv_head_num_ * max_k_len * size_per_head_, true); + v_cache_buf_ = k_cache_buf_ + batch_size * local_kv_head_num_ * max_k_len * size_per_head_; qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true); // qkv_buf_2_ has padding - qkv_buf_2_ = - (T*)allocator_->reMalloc(qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_hidden_units_, true); + qkv_buf_2_ = (T*)allocator_->reMalloc( + qkv_buf_2_, sizeof(T) * batch_size * max_q_len * local_head_num_ * size_per_head_, true); } // qkv_buf_3_ padding is removed - qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_hidden_units_, true); + qkv_buf_3_ = (T*)allocator_->reMalloc(qkv_buf_3_, sizeof(T) * num_token * local_head_num_ * size_per_head_, true); is_allocate_buffer_ = true; } @@ -152,7 +157,7 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* ////////////////////////////////////////////// /// transpose qkv & apply rotary embedding & rebuild padding - /// qkv [B, s, 3, H, D] -> (q [B, H, s, D], k [B, H, s, D], v [B, H, s, D]) + /// qkv [B, s, H + 2kvH, D] -> (q [B, H, s, D], k [B, kvH, s, D], v [B, kvH, s, D]) invokeAddFusedQKVBiasTranspose(q_buf_2_, k_buf_2_, v_buf_2_, @@ -165,6 +170,7 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* max_q_len, // seq_len num_token, // batch_size * seq_len local_head_num_, + local_kv_head_num_, size_per_head_, rotary_embedding_dim_, neox_rotary_style_, @@ -173,16 +179,16 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* stream_); sync_check_cuda_error(); - const size_t layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_; + const size_t layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_; auto k_cache_ptrs = output_tensors->getPtr("key_cache"); auto v_cache_ptrs = output_tensors->getPtr("value_cache"); ////////////////////////////////////////////////////////// /// insert the k/v computed from inputs into k/v cache /// transpose kv -> kv cache - // put k/v_buf from shape [B, H, s, D] to - // k_buf_2 [B, H, s, D] -> key_cache [B, H, S[t:t+s], D/x, x] - // v_buf_2 [B, H, s, D] -> val_cache [B, H, S[t:t+s], D/x, x] + // put k/v_buf from shape [B, kvH, s, D] to + // k_buf_2 [B, kvH, s, D] 
-> key_cache [B, kvH, S[t:t+s], D/x, x] + // v_buf_2 [B, kvH, s, D] -> val_cache [B, kvH, S[t:t+s], D/x, x] invokeExtendKVCache(k_cache_ptrs, v_cache_ptrs, layer_offset, @@ -194,13 +200,14 @@ inline void LlamaContextAttentionLayer::forward(TensorMap* history_length, max_seq_len, size_per_head_, - local_head_num_, + local_kv_head_num_, stream_, quant_policy_, weights->past_kv_scale.data()); sync_check_cuda_error(); if (use_fmha_) { + FT_CHECK(local_head_num_ == local_kv_head_num_); fusedMultiHeadAttention(k_cache_ptrs, v_cache_ptrs, layer_offset, @@ -311,8 +318,9 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_c int quant, const float* kv_scale) { - // key_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] - // val_cache [B, H, S[:t+s], D/x, x] -> [B, H, t+s, D] + TM_LOG_ERROR("[LlamaContextAttentionLayer] head_n_rep=%d", (int)head_n_rep_); + // key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] + // val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] invokeTransposeKVCache(k_cache_buf_, v_cache_buf_, (const T**)key_cache_ptrs, @@ -324,6 +332,7 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_c max_seq_len, size_per_head_, local_head_num_, + head_n_rep_, stream_, quant, kv_scale); diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.h b/src/turbomind/models/llama/LlamaContextAttentionLayer.h index 960069464..6a90c9c5d 100644 --- a/src/turbomind/models/llama/LlamaContextAttentionLayer.h +++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.h @@ -35,6 +35,7 @@ class LlamaContextAttentionLayer { void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); LlamaContextAttentionLayer(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t rotary_embedding_dim, bool neox_rotary_style, @@ -49,7 +50,8 @@ class LlamaContextAttentionLayer { size_per_head_(size_per_head), hidden_units_(head_num * size_per_head), local_head_num_(head_num / tensor_para.world_size_), - local_hidden_units_(hidden_units_ / tensor_para.world_size_), + local_kv_head_num_(kv_head_num / tensor_para.world_size_), + head_n_rep_(head_num / kv_head_num), rotary_embedding_dim_(rotary_embedding_dim), neox_rotary_style_(neox_rotary_style), tensor_para_(tensor_para), @@ -61,6 +63,7 @@ class LlamaContextAttentionLayer { use_fmha_(use_fmha), quant_policy_(quant_policy) { + FT_CHECK(head_num % kv_head_num == 0); } void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaAttentionWeight* weights); @@ -93,8 +96,9 @@ class LlamaContextAttentionLayer { const size_t head_num_; const size_t size_per_head_; const size_t hidden_units_; + const size_t local_kv_head_num_; const size_t local_head_num_; - const size_t local_hidden_units_; + const size_t head_n_rep_; const size_t rotary_embedding_dim_; const bool is_free_buffer_after_forward_; diff --git a/src/turbomind/models/llama/LlamaContextDecoder.cc b/src/turbomind/models/llama/LlamaContextDecoder.cc index 7c8cf2e5e..bdefad774 100644 --- a/src/turbomind/models/llama/LlamaContextDecoder.cc +++ b/src/turbomind/models/llama/LlamaContextDecoder.cc @@ -62,11 +62,12 @@ void LlamaContextDecoder::freeBuffer() } template -void LlamaContextDecoder::initialize(bool use_fmha, int quant_policy) +void LlamaContextDecoder::initialize(size_t kv_head_num, bool use_fmha, int quant_policy) { h_pinned_token_num_ptr_ = (size_t*)allocator_->reMalloc(h_pinned_token_num_ptr_, sizeof(size_t), true, true); context_attention_layer_ = new LlamaContextAttentionLayer(head_num_, + 
kv_head_num, size_per_head_, rotary_embedding_dim_, false, // neox_rotary_style @@ -124,6 +125,7 @@ void LlamaContextDecoder::forwardSelfAttn(const Session& template LlamaContextDecoder::LlamaContextDecoder(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t inter_size, size_t num_layer, @@ -147,7 +149,7 @@ LlamaContextDecoder::LlamaContextDecoder(size_t head_num, tensor_para_(tensor_para), data_type_(getTensorType()) { - initialize(use_fmha, quant_policy); + initialize(kv_head_num, use_fmha, quant_policy); } template diff --git a/src/turbomind/models/llama/LlamaContextDecoder.h b/src/turbomind/models/llama/LlamaContextDecoder.h index 08200e3d3..8b5d4cfab 100644 --- a/src/turbomind/models/llama/LlamaContextDecoder.h +++ b/src/turbomind/models/llama/LlamaContextDecoder.h @@ -43,7 +43,7 @@ class LlamaContextDecoder: public BaseLayer { void allocateBuffer(size_t batch_size, size_t num_token, size_t max_q_len, size_t max_kv_len); void freeBuffer() override; - void initialize(bool use_fmha, int quant_policy); + void initialize(size_t kv_head_num, bool use_fmha, int quant_policy); size_t head_num_; size_t size_per_head_; @@ -88,6 +88,7 @@ class LlamaContextDecoder: public BaseLayer { public: LlamaContextDecoder(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t inter_size, size_t num_layer, diff --git a/src/turbomind/models/llama/LlamaDecoder.cc b/src/turbomind/models/llama/LlamaDecoder.cc index c98d1cb10..7a117c25b 100644 --- a/src/turbomind/models/llama/LlamaDecoder.cc +++ b/src/turbomind/models/llama/LlamaDecoder.cc @@ -28,6 +28,7 @@ namespace turbomind { template LlamaDecoder::LlamaDecoder(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t inter_size, size_t num_layer, @@ -51,7 +52,7 @@ LlamaDecoder::LlamaDecoder(size_t head_num, data_type_(getTensorType()) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); - initialize(quant_policy); + initialize(kv_head_num, quant_policy); } template @@ -63,11 +64,12 @@ LlamaDecoder::~LlamaDecoder() } template -void LlamaDecoder::initialize(int quant_policy) +void LlamaDecoder::initialize(size_t kv_head_num, int quant_policy) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); self_attention_layer_ = new LlamaDecoderSelfAttentionLayer(head_num_, + kv_head_num, size_per_head_, rotary_embedding_dim_, false, // neox_rotary_style diff --git a/src/turbomind/models/llama/LlamaDecoder.h b/src/turbomind/models/llama/LlamaDecoder.h index 5d90b69ed..a5065091b 100644 --- a/src/turbomind/models/llama/LlamaDecoder.h +++ b/src/turbomind/models/llama/LlamaDecoder.h @@ -35,7 +35,7 @@ class LlamaDecoder: public BaseLayer { void allocateBuffer() override; // deprecated void allocateBuffer(size_t batch_size); void freeBuffer() override; - void initialize(int quant_policy); + void initialize(size_t kv_head_num, int quant_policy); size_t head_num_; size_t size_per_head_; @@ -70,6 +70,7 @@ class LlamaDecoder: public BaseLayer { public: LlamaDecoder(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t inter_size, size_t num_layer, @@ -80,9 +81,9 @@ class LlamaDecoder: public BaseLayer { cublasMMWrapper* cublas_wrapper, IAllocator* allocator, bool is_free_buffer_after_forward, - int quant_policy), + int quant_policy); - ~LlamaDecoder() override; + ~LlamaDecoder() override; virtual void forward(std::unordered_map* output_tensors, const std::unordered_map* input_tensors, diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc index b5566b22e..48b43f8cd 100644 --- 
a/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.cc @@ -25,13 +25,18 @@ namespace turbomind { template -LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t hidden_units, +LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, size_t inter_size, WeightType weight_type, bool attn_bias, size_t tensor_para_size, size_t tensor_para_rank): - hidden_units_(hidden_units), + head_num_(head_num), + kv_head_num_(kv_head_num), + size_per_head_(size_per_head), + hidden_units_(head_num * size_per_head), inter_size_(inter_size), weight_type_(weight_type), attn_bias_(attn_bias), @@ -39,7 +44,7 @@ LlamaDecoderLayerWeight::LlamaDecoderLayerWeight(size_t hidden_units, tensor_para_rank_(tensor_para_rank) { self_attn_weights.qkv.input_dims = hidden_units_; - self_attn_weights.qkv.output_dims = 3 * hidden_units_ / tensor_para_size_; + self_attn_weights.qkv.output_dims = (head_num + 2 * kv_head_num) * size_per_head / tensor_para_size_; self_attn_weights.qkv.type = weight_type; self_attn_weights.output.input_dims = hidden_units_ / tensor_para_size_; diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index 1c3e6f599..bc0ada5dc 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -28,14 +28,16 @@ template struct LlamaDecoderLayerWeight { public: LlamaDecoderLayerWeight() = delete; - LlamaDecoderLayerWeight(size_t hidden_units, + LlamaDecoderLayerWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, size_t inter_size, WeightType weight_type, bool attn_bias, size_t tensor_para_size, size_t tensor_para_rank); ~LlamaDecoderLayerWeight(); - LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete; + LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete; LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete; void loadModel(std::string dir_path, FtCudaDataType model_file_type); @@ -46,6 +48,9 @@ struct LlamaDecoderLayerWeight { LlamaFfnWeight ffn_weights{}; private: + size_t head_num_; + size_t kv_head_num_; + size_t size_per_head_; size_t hidden_units_; size_t inter_size_; WeightType weight_type_; diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc index 599bf3a68..3909f4f66 100644 --- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc +++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.cc @@ -23,6 +23,7 @@ #include "src/turbomind/models/llama/llama_kernels.h" #include "src/turbomind/models/llama/llama_utils.h" #include "src/turbomind/utils/cuda_utils.h" +#include "src/turbomind/utils/logger.h" #include "src/turbomind/utils/nvtx_utils.h" #include // #include @@ -56,6 +57,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int inference_batch_size, const int beam_width, const int head_num, + const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const int memory_max_len, @@ -81,11 +83,11 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // Prepare the parameters. 
Masked_multihead_attention_params params; memset(¶ms, 0, sizeof(params)); - int hidden_units = head_num * size_per_head; + // int hidden_units = head_num * size_per_head; if (qkv_bias != nullptr) { params.q_bias = reinterpret_cast(qkv_bias); - params.k_bias = reinterpret_cast(qkv_bias) + hidden_units; - params.v_bias = reinterpret_cast(qkv_bias) + 2 * hidden_units; + params.k_bias = reinterpret_cast(qkv_bias) + head_num * size_per_head; + params.v_bias = reinterpret_cast(qkv_bias) + (head_num + kv_head_num) * size_per_head; } else { params.q_bias = nullptr; @@ -97,13 +99,16 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, params.out = reinterpret_cast(context_buf); // Set the input buffers. + // [B, nH + kvH, D] params.q = reinterpret_cast(qkv_buf); - params.k = reinterpret_cast(qkv_buf) + hidden_units; - params.v = reinterpret_cast(qkv_buf) + 2 * hidden_units; + params.k = reinterpret_cast(qkv_buf) + head_num * size_per_head; + params.v = reinterpret_cast(qkv_buf) + (head_num + kv_head_num) * size_per_head; - params.stride = 3 * hidden_units; + params.stride = (head_num + 2 * kv_head_num) * size_per_head; params.finished = const_cast(finished); + FT_CHECK(k_cache_per_sample && v_cache_per_sample); + params.k_cache = reinterpret_cast(key_cache); params.v_cache = reinterpret_cast(value_cache); params.k_cache_per_sample = reinterpret_cast(k_cache_per_sample); @@ -118,8 +123,10 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf, params.max_prefix_prompt_length = max_prefix_prompt_length; params.length_per_sample = sequence_lengths; // max_input_length + current output length // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation - params.timestep = step + max_prefix_prompt_length - 1; - params.num_heads = head_num; + params.timestep = step + max_prefix_prompt_length - 1; + params.num_heads = head_num; + params.num_kv_heads = kv_head_num; + params.hidden_size_per_head = size_per_head; params.rotary_embedding_dim = rotary_embedding_dim; // Note: keep norm factor (sqrt(K_dim)) when adopting megatron T5 structure (may adjust) @@ -158,8 +165,11 @@ template void LlamaDecoderSelfAttentionLayer::allocateBuffer(size_t batch_size, int key_len, int max_memory_len) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); - qkv_buf_ = - reinterpret_cast(allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * 3 * local_hidden_units_, false)); + + const size_t local_q_kv_head_num = local_head_num_ + 2 * local_kv_head_num_; + + qkv_buf_ = reinterpret_cast( + allocator_->reMalloc(qkv_buf_, sizeof(T) * batch_size * local_q_kv_head_num * size_per_head_, false)); context_buf_ = reinterpret_cast(allocator_->reMalloc(context_buf_, sizeof(T) * batch_size * local_hidden_units_, false)); @@ -197,7 +207,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o * * output tensors: * \param attention_output [batch_size, hidden_units], - * \param key_cache [batch, local_head_num, size_per_head / x, memory_max_len, x] + * \param key_cache [batch, local_head_num, memory_max_len, size_per_head] * \param value_cache [batch, local_head_num, memory_max_len, size_per_head] */ @@ -228,7 +238,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o linear_.forward(qkv_buf_, input_query_data, batch_size, weights->qkv); POP_RANGE; - const auto kv_cache_layer_offset = layer_id * local_head_num_ * max_seq_len * size_per_head_; + const auto kv_cache_layer_offset = layer_id * local_kv_head_num_ * max_seq_len * size_per_head_; const int 
memory_len = max_seq_len; fusedQKV_masked_attention_dispatch( @@ -248,6 +258,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* o batch_size, beam_width, local_head_num_, + local_kv_head_num_, size_per_head_, rotary_embedding_dim_, memory_len, diff --git a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h index 3f7098cd2..78e439243 100644 --- a/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h +++ b/src/turbomind/models/llama/LlamaDecoderSelfAttentionLayer.h @@ -34,6 +34,7 @@ class LlamaDecoderSelfAttentionLayer { void allocateBuffer(size_t batch_size, int key_len, int max_memory_len); LlamaDecoderSelfAttentionLayer(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t rotary_embedding_dim, bool neox_rotary_style, @@ -44,9 +45,11 @@ class LlamaDecoderSelfAttentionLayer { bool is_free_buffer_after_forward, int quant_policy): head_num_(head_num), + kv_head_num_(kv_head_num), size_per_head_(size_per_head), hidden_units_(head_num * size_per_head), local_head_num_(head_num / tensor_para.world_size_), + local_kv_head_num_(kv_head_num_ / tensor_para.world_size_), local_hidden_units_(hidden_units_ / tensor_para.world_size_), rotary_embedding_dim_(rotary_embedding_dim), neox_rotary_style_(neox_rotary_style), @@ -68,9 +71,11 @@ class LlamaDecoderSelfAttentionLayer { private: const size_t head_num_; + const size_t kv_head_num_; const size_t size_per_head_; const size_t hidden_units_; const size_t local_head_num_; + const size_t local_kv_head_num_; const size_t local_hidden_units_; const size_t rotary_embedding_dim_; const bool is_free_buffer_after_forward_; diff --git a/src/turbomind/models/llama/LlamaV2.cc b/src/turbomind/models/llama/LlamaV2.cc index c6ca65fba..2fd2bab38 100644 --- a/src/turbomind/models/llama/LlamaV2.cc +++ b/src/turbomind/models/llama/LlamaV2.cc @@ -39,6 +39,7 @@ namespace turbomind { template LlamaV2::LlamaV2(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t inter_size, size_t num_layer, @@ -102,8 +103,11 @@ LlamaV2::LlamaV2(size_t head_num, else { elem_bits = sizeof(T) * 8; } + + const size_t local_kv_head_num = kv_head_num / tensor_para.world_size_; + kv_cache_mgr_ = std::make_unique(num_layer_, - local_head_num_, + local_kv_head_num, size_per_head_, session_len, elem_bits, @@ -111,7 +115,7 @@ LlamaV2::LlamaV2(size_t head_num, cache_chunk_size, tensor_para.rank_, allocator); - initialize(use_context_fmha, quant_policy); + initialize(kv_head_num, use_context_fmha, quant_policy); start(); } @@ -126,11 +130,12 @@ LlamaV2::~LlamaV2() } template -void LlamaV2::initialize(bool use_context_fmha, int quant_policy) +void LlamaV2::initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy) { TM_LOG_DEBUG(__PRETTY_FUNCTION__); context_decoder_ = new LlamaContextDecoder(head_num_, + kv_head_num, size_per_head_, inter_size_, num_layer_, @@ -145,6 +150,7 @@ void LlamaV2::initialize(bool use_context_fmha, int quant_policy) quant_policy); decoder_ = new LlamaDecoder(head_num_, + kv_head_num, size_per_head_, inter_size_, num_layer_, diff --git a/src/turbomind/models/llama/LlamaV2.h b/src/turbomind/models/llama/LlamaV2.h index eec815558..0b5f55e9a 100644 --- a/src/turbomind/models/llama/LlamaV2.h +++ b/src/turbomind/models/llama/LlamaV2.h @@ -49,6 +49,7 @@ class LlamaV2 { ~LlamaV2(); LlamaV2(size_t head_num, + size_t kv_head_num, size_t size_per_head, size_t inter_size, size_t num_layer, @@ -90,7 +91,7 @@ class LlamaV2 { void internalThreadEntry(int device_id); 
- void initialize(bool use_context_fmha, int quant_policy); + void initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy); void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step); diff --git a/src/turbomind/models/llama/LlamaWeight.cc b/src/turbomind/models/llama/LlamaWeight.cc index 9ac566d58..cc3a0ddcf 100644 --- a/src/turbomind/models/llama/LlamaWeight.cc +++ b/src/turbomind/models/llama/LlamaWeight.cc @@ -23,7 +23,9 @@ namespace turbomind { template -LlamaWeight::LlamaWeight(size_t hidden_units, +LlamaWeight::LlamaWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, size_t inter_size, size_t vocab_size, size_t num_layer, @@ -32,7 +34,7 @@ LlamaWeight::LlamaWeight(size_t hidden_units, size_t tensor_para_size, size_t tensor_para_rank, int prefix_cache_len): - hidden_units_(hidden_units), + hidden_units_(head_num * size_per_head), inter_size_(inter_size), vocab_size_(vocab_size), num_layer_(num_layer), @@ -43,8 +45,14 @@ LlamaWeight::LlamaWeight(size_t hidden_units, { decoder_layer_weights.reserve(num_layer_); for (unsigned l = 0; l < num_layer_; ++l) { - decoder_layer_weights.push_back(new LlamaDecoderLayerWeight( - hidden_units_, inter_size_, weight_type_, attn_bias, tensor_para_size_, tensor_para_rank_)); + decoder_layer_weights.push_back(new LlamaDecoderLayerWeight(head_num, + kv_head_num, + size_per_head, + inter_size_, + weight_type_, + attn_bias, + tensor_para_size_, + tensor_para_rank_)); } mallocWeights(); diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index 898edae64..dc31a5164 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -28,7 +28,9 @@ namespace turbomind { template struct LlamaWeight { LlamaWeight() = default; - LlamaWeight(size_t hidden_units, + LlamaWeight(size_t head_num, + size_t kv_head_num, + size_t size_per_head, size_t inter_size, size_t vocab_size, size_t num_layer, @@ -40,7 +42,7 @@ struct LlamaWeight { ~LlamaWeight(); - LlamaWeight(const LlamaWeight& other) = delete; + LlamaWeight(const LlamaWeight& other) = delete; LlamaWeight& operator=(const LlamaWeight& other) = delete; void loadModel(std::string dir_path); diff --git a/src/turbomind/models/llama/llama_kernels.cu b/src/turbomind/models/llama/llama_kernels.cu index db089f539..9de021cff 100644 --- a/src/turbomind/models/llama/llama_kernels.cu +++ b/src/turbomind/models/llama/llama_kernels.cu @@ -488,6 +488,7 @@ __global__ void transpose_value_cache(T* v_dst, // const T** v_src, const size_t src_offset, const int head_num, + const int head_n_rep, const int size_per_head, const int* seq_length, const int max_kv_len, @@ -511,9 +512,9 @@ __global__ void transpose_value_cache(T* v_dst, // if (v_seq_len_id < seq_len) { // [B, H, s, D/x] <- [B, H, S[:s], D/x] - const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len + // H - v_seq_len_id * size_per_head_div_x + // s - v_head_size_id; // D/x + const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B head_id * size_per_head_div_x * max_kv_len + // H @@ -529,6 +530,7 @@ __global__ void transpose_value_cache_int8(T* v_dst, // const int8_t** v_src, const size_t src_offset, const int head_num, + const int head_n_rep, const int size_per_head, const int* seq_length, const int max_kv_len, @@ -553,9 +555,9 @@ 
__global__ void transpose_value_cache_int8(T* v_dst, // if (v_seq_len_id < seq_len) { // [B, H, s, D/x] <- [B, H, S[:s], D/x] - const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len + // H - v_seq_len_id * size_per_head_div_x + // s - v_head_size_id; // D/x + const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H + v_seq_len_id * size_per_head_div_x + // s + v_head_size_id; // D/x const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B head_id * size_per_head_div_x * max_kv_len + // H @@ -583,6 +585,7 @@ void invokeTransposeKVCache(T* key_cache_trans, int max_seq_len, int size_per_head, int head_num, + int head_n_rep, cudaStream_t stream, int quant, const float* kv_scale) @@ -597,6 +600,7 @@ void invokeTransposeKVCache(T* key_cache_trans, reinterpret_cast(key_cache), src_offset, head_num, + head_n_rep, size_per_head, key_length, max_kv_len, @@ -607,6 +611,7 @@ void invokeTransposeKVCache(T* key_cache_trans, reinterpret_cast(val_cache), src_offset, head_num, + head_n_rep, size_per_head, key_length, max_kv_len, @@ -614,11 +619,25 @@ void invokeTransposeKVCache(T* key_cache_trans, kv_scale[1]); } else { - transpose_value_cache<<>>( - key_cache_trans, key_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len); - - transpose_value_cache<<>>( - val_cache_trans, val_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len); + transpose_value_cache<<>>(key_cache_trans, + key_cache, + src_offset, + head_num, + head_n_rep, + size_per_head, + key_length, + max_kv_len, + max_seq_len); + + transpose_value_cache<<>>(val_cache_trans, + val_cache, + src_offset, + head_num, + head_n_rep, + size_per_head, + key_length, + max_kv_len, + max_seq_len); } } @@ -633,6 +652,7 @@ template void invokeTransposeKVCache(float*, int, int, int, + int, cudaStream_t stream, int, const float*); @@ -647,6 +667,7 @@ template void invokeTransposeKVCache(half*, int, int, int, + int, cudaStream_t stream, int, const float*); diff --git a/src/turbomind/models/llama/llama_kernels.h b/src/turbomind/models/llama/llama_kernels.h index e88ecd3d1..90d4dc94a 100644 --- a/src/turbomind/models/llama/llama_kernels.h +++ b/src/turbomind/models/llama/llama_kernels.h @@ -62,6 +62,7 @@ void invokeTransposeKVCache(T* key_cache_trans, int max_seq_len, int size_per_head, int head_num, + int head_n_rep, cudaStream_t stream, int quant_policy, const float* kv_scale); diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc index abecf5b97..bd09f3805 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc @@ -59,6 +59,11 @@ std::shared_ptr AbstractTransformerModel::createLlamaM template void LlamaTritonModel::handleMissingParams() { + if (kv_head_num_ == 0) { + kv_head_num_ = head_num_; + TM_LOG_WARNING("[LlamaTritonModel] `kv_head_num` is not set, default to `head_num` (%d).", (int)kv_head_num_); + } + if (!max_batch_size_) { max_batch_size_ = 32; TM_LOG_WARNING("[LlamaTritonModel] `max_batch_size` is not set, default to %d.", (int)max_batch_size_); @@ -112,6 +117,7 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size, model_name_ = reader.Get("llama", "model_name"); head_num_ = reader.GetInteger("llama", "head_num"); + kv_head_num_ = reader.GetInteger("llama", "kv_head_num", 0); size_per_head_ = reader.GetInteger("llama", "size_per_head"); inter_size_ = 
reader.GetInteger("llama", "inter_size"); num_layer_ = reader.GetInteger("llama", "num_layer"); @@ -211,6 +217,7 @@ std::unique_ptr> LlamaTritonModel::createSh ft::FT_CHECK(pipeline_para.world_size_ = pipeline_para_size_); auto llama = std::make_unique>(head_num_, + kv_head_num_, size_per_head_, inter_size_, num_layer_, @@ -283,7 +290,9 @@ void LlamaTritonModel::createSharedWeights(int device_id, int rank) const int tensor_para_rank = rank % tensor_para_size_; const int pipeline_para_rank = rank / tensor_para_size_; ft::FT_CHECK(pipeline_para_size_ == 1 && pipeline_para_rank == 0); - shared_weights_[device_id] = std::make_shared>(head_num_ * size_per_head_, + shared_weights_[device_id] = std::make_shared>(head_num_, + kv_head_num_, + size_per_head_, inter_size_, vocab_size_, num_layer_, @@ -301,16 +310,16 @@ std::string LlamaTritonModel::toString() { std::stringstream ss; ss << "Model: " - << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_ - << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nattn_bias: " << attn_bias_ - << "\nmax_batch_size: " << max_batch_size_ << "\nmax_context_token_num: " << max_context_token_num_ - << "\nsession_len: " << session_len_ << "\nstep_length: " << step_length_ - << "\ncache_max_entry_count: " << cache_max_entry_count_ << "\ncache_chunk_size: " << cache_chunk_size_ - << "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_ - << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_ - << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_ - << "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_ - << "\nquant_policy: " << quant_policy_ << std::endl; + << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ << "\nsize_per_head: " << size_per_head_ + << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ + << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << max_batch_size_ + << "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_ + << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_ + << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_ + << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_ + << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ + << "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_ + << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl; return ss.str(); } diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.h b/src/turbomind/triton_backend/llama/LlamaTritonModel.h index d44cd768b..d0ecbc482 100644 --- a/src/turbomind/triton_backend/llama/LlamaTritonModel.h +++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.h @@ -74,6 +74,7 @@ struct LlamaTritonModel: public AbstractTransformerModel { std::shared_ptr custom_all_reduce_comm = nullptr); size_t head_num_; + size_t kv_head_num_; size_t size_per_head_; size_t inter_size_; size_t num_layer_; From ee9dd3cbe70cb2ed2027980289f10af16902ef26 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Wed, 19 Jul 2023 18:42:30 +0000 Subject: [PATCH 2/7] fix model conversion --- lmdeploy/serve/turbomind/deploy.py | 38 ++++++++++++++++-------------- 1 file changed, 20 
insertions(+), 18 deletions(-) diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py index cbb1d8eea..500c03432 100644 --- a/lmdeploy/serve/turbomind/deploy.py +++ b/lmdeploy/serve/turbomind/deploy.py @@ -95,6 +95,7 @@ def tokenizer_info(model_path: str): def export(model_name: str, num_layer: int, norm_eps: float, + kv_head_num: int, model_params: dict, tokenizer_path: str, out_dir: str, @@ -169,6 +170,7 @@ def save_bin(param: torch.Tensor, name): cfg = dict(llama=dict( model_name=model_name, head_num=head_num, + kv_head_num=kv_head_num, size_per_head=size_per_head, vocab_size=vocab_size, num_layer=num_layer, @@ -186,7 +188,7 @@ def save_bin(param: torch.Tensor, name): step_length=1, cache_max_entry_count=48, cache_chunk_size=1, - use_context_fmha=1, + use_context_fmha=int(kv_head_num == head_num), quant_policy=0, tensor_para_size=tp)) @@ -200,6 +202,15 @@ def save_bin(param: torch.Tensor, name): return True +def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int, + dim: int): + + def reshape(x): + return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1) + + return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1) + + def deploy_llama(model_name: str, model_path: str, tokenizer_path: str, triton_models_path: str, tp: int): """Deploy a model with huggingface transformers' format. @@ -225,6 +236,8 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str, model_arg = json.load(f) num_layer = model_arg['n_layers'] norm_eps = model_arg['norm_eps'] + head_num = model_arg.get('n_heads', 32) + kv_head_num = model_arg.get('n_kv_heads', head_num) except Exception as e: print(f'get "n_layers" and "norm_eps" from {params_path} failed: {e}') return False @@ -270,7 +283,6 @@ def get_param(_name, _size): else: # bias param = get_param(param_name, [size]) param.data = param_data - elif i == 0: param = get_param(param_name, param_data.size()) param.data = param_data @@ -293,14 +305,14 @@ def get_param(_name, _size): qkv = tuple(map(model_params.pop, _qkv)) except KeyError: break - # concat by output_dims - qkv = torch.stack(qkv, dim=qkv[0].dim() - 1) + # concat by heads + qkv = merge_qkv(*qkv, tp, dim=2 if t == 'weight' else 1) print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape) model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv assert i == 0 or num_layer == i, f'miss matched layers: {num_layer} vs {i}' - return export(model_name, num_layer, norm_eps, model_params, + return export(model_name, num_layer, norm_eps, kv_head_num, model_params, tokenizer_path, triton_models_path, tp) @@ -318,16 +330,6 @@ def permute(x: torch.Tensor): 1).transpose(1, 2).reshape(dim, 1) -def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - kv_head_num: int, dim: int): - - def reshape(x): - return x.view(x.size(0), kv_head_num, -1) if dim == 2 else x.view( - kv_head_num, -1) - - return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1) - - def deploy_hf(model_name: str, model_path: str, tokenizer_path: str, triton_models_path: str, tp: int): """Deploy a model with huggingface transformers' format. 
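Patch 2 regroups the fused QKV tensor by tensor-parallel rank rather than by KV head. A rough sketch of the effect, again with made-up sizes and assuming both `head_num` and `kv_head_num` are divisible by `tp`:

```python
import torch

def merge_qkv(q, k, v, tp, dim):
    # Patch 2 variant: group by tensor-parallel rank instead of by KV head.
    def reshape(x):
        return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)

# Illustrative sizes only; assumes head_num and kv_head_num divide evenly by tp.
hidden, head_num, kv_head_num, head_dim, tp = 8, 8, 2, 2, 2
q = torch.randn(hidden, head_num * head_dim)     # [8, 16]
k = torch.randn(hidden, kv_head_num * head_dim)  # [8, 4]
v = torch.randn(hidden, kv_head_num * head_dim)  # [8, 4]

qkv = merge_qkv(q, k, v, tp, dim=2)  # [8, 2, 12]: (head_num + 2 * kv_head_num) * head_dim // tp per rank
shards = torch.unbind(qkv, dim=1)    # one contiguous Q/K/V shard per tensor-parallel rank
print(qkv.shape, shards[0].shape)    # torch.Size([8, 2, 12]) torch.Size([8, 12])
```

Splitting along dim 1 then hands each rank a contiguous shard containing its share of Q, K and V heads, which matches the `split_dim = 1` path for `w_qkv` in `save_bin`.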
@@ -432,10 +434,10 @@ def get_tensor_transposed(name: str): q = permute(q) k = permute(k) if suffix == _qweight: # weight, qweight - qkv = merge_qkv(q, k, v, kv_head_num, dim=2) + qkv = merge_qkv(q, k, v, tp, dim=2) print(suffix, qkv.shape) else: # scales, zeros, bias - qkv = merge_qkv(q, k, v, kv_head_num, dim=1) + qkv = merge_qkv(q, k, v, tp, dim=1) print(suffix, qkv.shape) for k, v in [('w_qkv', qkv), ('wo', o)]: model_params[f'layers.{i}.attention.{k}.{suffix}'] = v @@ -471,7 +473,7 @@ def get_tensor_transposed(name: str): for ft, hf in other: model_params[ft] = get_tensor(hf) - return export(model_name, num_layer, norm_eps, model_params, + return export(model_name, num_layer, norm_eps, kv_head_num, model_params, tokenizer_path, triton_models_path, tp) From 493eb6abf18304e538b144b70caf365c43bb6d6e Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 20 Jul 2023 04:40:02 +0000 Subject: [PATCH 3/7] fix lint & remove dev log --- ...der_masked_multihead_attention_template.cuh | 18 ++++++++++++------ .../kernels/unfused_attention_kernels.cu | 1 - .../models/llama/LlamaContextAttentionLayer.cc | 1 - .../models/llama/LlamaDecoderLayerWeight.h | 2 +- src/turbomind/models/llama/LlamaWeight.h | 2 +- 5 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh index 892c31a9b..b6416910f 100644 --- a/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh +++ b/src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh @@ -80,7 +80,8 @@ namespace mmha { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct Qk_vec_m_ {}; +struct Qk_vec_m_ { +}; template<> struct Qk_vec_m_ { @@ -180,7 +181,8 @@ struct Qk_vec_k_<__nv_fp8_e4m3, 256> { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct K_vec_m_ {}; +struct K_vec_m_ { +}; template<> struct K_vec_m_ { @@ -261,7 +263,8 @@ struct K_vec_k_<__nv_fp8_e4m3, 1> { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct V_vec_m_ {}; +struct V_vec_m_ { +}; template<> struct V_vec_m_ { @@ -341,7 +344,8 @@ struct V_vec_k_<__nv_fp8_e4m3, 16> { #ifdef MMHA_USE_FP32_ACUM_FOR_FMA template -struct Qk_vec_acum_fp32_ {}; +struct Qk_vec_acum_fp32_ { +}; template<> struct Qk_vec_acum_fp32_ { @@ -423,7 +427,8 @@ struct Qk_vec_acum_fp32_ { //////////////////////////////////////////////////////////////////////////////////////////////////// template -struct K_vec_acum_fp32_ {}; +struct K_vec_acum_fp32_ { +}; template<> struct K_vec_acum_fp32_ { @@ -485,7 +490,8 @@ struct K_vec_acum_fp32_ { #ifdef MMHA_USE_FP32_ACUM_FOR_OUT template -struct V_vec_acum_fp32_ {}; +struct V_vec_acum_fp32_ { +}; template<> struct V_vec_acum_fp32_ { diff --git a/src/turbomind/kernels/unfused_attention_kernels.cu b/src/turbomind/kernels/unfused_attention_kernels.cu index 70c2e2e60..9c19df471 100644 --- a/src/turbomind/kernels/unfused_attention_kernels.cu +++ b/src/turbomind/kernels/unfused_attention_kernels.cu @@ -1552,7 +1552,6 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, const int int8_mode, cudaStream_t stream) { - TM_LOG_ERROR("invokeAddFusedQKVBiasTranspose"); FT_CHECK(rotary_embedding_dim); 
FT_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with prefix prompt"); // TODO(mseznec) // To implement rotary embeddings, each thread processes two QKV elems: diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc index 806053c55..cbe829b54 100644 --- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc +++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc @@ -318,7 +318,6 @@ void LlamaContextAttentionLayer::unfusedMultiHeadAttention(T** key_c int quant, const float* kv_scale) { - TM_LOG_ERROR("[LlamaContextAttentionLayer] head_n_rep=%d", (int)head_n_rep_); // key_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] // val_cache [B, kvH, S[:t+s], D/x, x] -> [B, qH, t+s, D] invokeTransposeKVCache(k_cache_buf_, diff --git a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h index bc0ada5dc..d83ea9879 100644 --- a/src/turbomind/models/llama/LlamaDecoderLayerWeight.h +++ b/src/turbomind/models/llama/LlamaDecoderLayerWeight.h @@ -37,7 +37,7 @@ struct LlamaDecoderLayerWeight { size_t tensor_para_size, size_t tensor_para_rank); ~LlamaDecoderLayerWeight(); - LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete; + LlamaDecoderLayerWeight(const LlamaDecoderLayerWeight& other) = delete; LlamaDecoderLayerWeight& operator=(const LlamaDecoderLayerWeight& other) = delete; void loadModel(std::string dir_path, FtCudaDataType model_file_type); diff --git a/src/turbomind/models/llama/LlamaWeight.h b/src/turbomind/models/llama/LlamaWeight.h index dc31a5164..d6499021a 100644 --- a/src/turbomind/models/llama/LlamaWeight.h +++ b/src/turbomind/models/llama/LlamaWeight.h @@ -42,7 +42,7 @@ struct LlamaWeight { ~LlamaWeight(); - LlamaWeight(const LlamaWeight& other) = delete; + LlamaWeight(const LlamaWeight& other) = delete; LlamaWeight& operator=(const LlamaWeight& other) = delete; void loadModel(std::string dir_path); From 6e283d62926c5c16742a52bc70febdd73a4a4c55 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 20 Jul 2023 04:43:46 +0000 Subject: [PATCH 4/7] update news --- README.md | 1 + README_zh-CN.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 7a467ce52..49602ea4f 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ ______________________________________________________________________ ## News +\[2023/07\] TurboMind supports Llama-2 70B with GQA \[2023/07\] TurboMind supports tensor-parallel inference of InternLM. ______________________________________________________________________ diff --git a/README_zh-CN.md b/README_zh-CN.md index 0524e9be4..e34133450 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -13,6 +13,7 @@ ______________________________________________________________________ ## 更新 +\[2023/07\] TurboMind 支持使用 GQA 的 Llama-2 70B 模型 \[2023/07\] TurboMind 支持 InternLM 的 Tensor Parallel 推理 ______________________________________________________________________ From 5b1621d4e0302a5f60dda3aabf45062e3cce7bee Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 20 Jul 2023 04:49:06 +0000 Subject: [PATCH 5/7] minor --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 49602ea4f..a901f9152 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ ______________________________________________________________________ ## News -\[2023/07\] TurboMind supports Llama-2 70B with GQA +\[2023/07\] TurboMind supports Llama-2 70B with GQA. 
\[2023/07\] TurboMind supports tensor-parallel inference of InternLM. ______________________________________________________________________ From 46420dca9f506839e0c1099d56a5d5f4b691a0c6 Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 20 Jul 2023 11:29:55 +0000 Subject: [PATCH 6/7] fix allocation size --- src/turbomind/models/llama/LlamaContextAttentionLayer.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc index cbe829b54..7f44ea215 100644 --- a/src/turbomind/models/llama/LlamaContextAttentionLayer.cc +++ b/src/turbomind/models/llama/LlamaContextAttentionLayer.cc @@ -58,9 +58,10 @@ void LlamaContextAttentionLayer::allocateBuffer(size_t batch_size, } } else { + // kv heads are repeated for unfused attention k_cache_buf_ = (T*)allocator_->reMalloc( - k_cache_buf_, 2 * sizeof(T) * batch_size * local_kv_head_num_ * max_k_len * size_per_head_, true); - v_cache_buf_ = k_cache_buf_ + batch_size * local_kv_head_num_ * max_k_len * size_per_head_; + k_cache_buf_, 2 * sizeof(T) * batch_size * local_head_num_ * max_k_len * size_per_head_, true); + v_cache_buf_ = k_cache_buf_ + batch_size * local_head_num_ * max_k_len * size_per_head_; qk_buf_ = (T*)allocator_->reMalloc(qk_buf_, sizeof(T) * batch_size * local_head_num_ * max_q_len * max_k_len, true); From 277ac1fd9e4bf26faba6afc6bab23de3cd6e228d Mon Sep 17 00:00:00 2001 From: Li Zhang Date: Thu, 20 Jul 2023 14:03:26 +0000 Subject: [PATCH 7/7] fix split_dim for w_qkv.bias --- lmdeploy/serve/turbomind/deploy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmdeploy/serve/turbomind/deploy.py b/lmdeploy/serve/turbomind/deploy.py index 500c03432..8ebea7244 100644 --- a/lmdeploy/serve/turbomind/deploy.py +++ b/lmdeploy/serve/turbomind/deploy.py @@ -139,7 +139,7 @@ def save_bin(param: torch.Tensor, name): if key == 'w1': inter_size = param_data.shape[-1] elif key == 'w_qkv': - split_dim = 1 + split_dim = -2 elif key in ['w2', 'wo']: if ext in ['scales', 'zeros', 'bias']: copy = True
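
For reference, the qkv merge introduced by this series can be exercised in isolation. The sketch below is not part of the patches; it copies merge_qkv from deploy.py and runs it on made-up Llama-2-70B-like shapes (hidden size 8192, 64 query heads, 8 KV heads, head dim 128, tp = 2; the tensor names and the print are illustrative only) to show why the merged tensor is grouped by tensor-parallel rank and why patch 7 switches split_dim for w_qkv to -2:

    import torch

    def merge_qkv(q, k, v, tp, dim):
        def reshape(x):
            return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)
        return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)

    hidden, head_num, kv_head_num, head_dim, tp = 8192, 64, 8, 128, 2
    q = torch.zeros(hidden, head_num * head_dim)     # [8192, 8192]
    k = torch.zeros(hidden, kv_head_num * head_dim)  # [8192, 1024]
    v = torch.zeros(hidden, kv_head_num * head_dim)  # [8192, 1024]

    # weight path (dim=2): each rank slice holds 32 Q heads + 4 K heads + 4 V heads
    w_qkv = merge_qkv(q, k, v, tp, dim=2)            # [8192, 2, 5120]

    # bias/scales/zeros path (dim=1)
    b_qkv = merge_qkv(torch.zeros(head_num * head_dim),
                      torch.zeros(kv_head_num * head_dim),
                      torch.zeros(kv_head_num * head_dim), tp, dim=1)  # [2, 5120]

    # dim -2 is the rank axis in both cases (dim 1 of the 3-D weight,
    # dim 0 of the 2-D bias), so split_dim = -2 works for every w_qkv tensor.
    print(w_qkv.shape, b_qkv.shape)

Grouping by rank keeps each rank's query heads together with the K/V heads of its own group, so no rank ever needs KV weights owned by another rank.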
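Patch 6 follows from the same head grouping: the unfused attention fallback broadcasts each KV head to its head_n_rep_ query heads when transposing the cache (invokeTransposeKVCache), so the K/V scratch buffers must be sized with local_head_num_ rather than local_kv_head_num_. A minimal sketch of that broadcast, with assumed shapes and repeat_interleave standing in for the CUDA kernel:

    import torch

    head_num, kv_head_num, seq_len, head_dim = 64, 8, 16, 128
    head_n_rep = head_num // kv_head_num      # 8 query heads share each KV head

    k_cache = torch.zeros(kv_head_num, seq_len, head_dim)     # [8, 16, 128]
    k_unfused = k_cache.repeat_interleave(head_n_rep, dim=0)  # [64, 16, 128]

    # After broadcasting, the scratch holds head_num heads, not kv_head_num,
    # hence batch * local_head_num_ * max_k_len * size_per_head_ elements
    # per K (and per V) buffer in patch 6.
    print(k_unfused.shape)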