[Feature] Support Llama-2 with GQA #147

Merged (8 commits) · Jul 21, 2023

README.md: 2 additions & 1 deletion
@@ -13,8 +13,9 @@

## News 🎉

- \[2023/07\] TurboMind supports Llama-2 70B with GQA.
- \[2023/07\] TurboMind supports Llama-2 7B/13B.
- \[2023/07\] TurboMind supports tensor-parallel inference of InternLM.
- \[2023/07\] TurboMind supports llama2 7b/13b.

______________________________________________________________________

README_zh-CN.md: 2 additions & 1 deletion
@@ -13,8 +13,9 @@

## Updates 🎉

- \[2023/07\] TurboMind supports the Llama-2 70B model with GQA
- \[2023/07\] TurboMind supports Llama-2 7B/13B models
- \[2023/07\] TurboMind supports tensor-parallel inference of InternLM
- \[2023/07\] TurboMind supports Llama2 7b/13b models

______________________________________________________________________

examples/cpp/llama/CMakeLists.txt: 1 addition & 1 deletion
@@ -3,6 +3,6 @@
add_executable(llama_triton_example llama_triton_example.cc)
target_link_libraries(llama_triton_example PUBLIC -lcublas -lcublasLt -lcudart
LlamaTritonBackend TransformerTritonBackend mpi_utils nccl_utils
nvtx_utils word_list glog)
nvtx_utils word_list)

install(TARGETS llama_triton_example DESTINATION ${CMAKE_INSTALL_PREFIX}/bin)
lmdeploy/serve/turbomind/deploy.py: 28 additions & 11 deletions
@@ -95,6 +95,7 @@ def tokenizer_info(model_path: str):
def export(model_name: str,
num_layer: int,
norm_eps: float,
kv_head_num: int,
model_params: dict,
tokenizer_path: str,
out_dir: str,
@@ -133,10 +134,12 @@ def save_bin(param: torch.Tensor, name):
if key == 'w_qkv' and ext == 'bias':
attn_bias = True
copy = False
if key in ['w1', 'w3', 'w_qkv']:
if key in ['w1', 'w3']:
split_dim = -1
if key == 'w1':
inter_size = param_data.shape[-1]
elif key == 'w_qkv':
split_dim = -2
elif key in ['w2', 'wo']:
if ext in ['scales', 'zeros', 'bias']:
copy = True
@@ -167,6 +170,7 @@ def save_bin(param: torch.Tensor, name):
cfg = dict(llama=dict(
model_name=model_name,
head_num=head_num,
kv_head_num=kv_head_num,
size_per_head=size_per_head,
vocab_size=vocab_size,
num_layer=num_layer,
@@ -184,7 +188,7 @@ def save_bin(param: torch.Tensor, name):
step_length=1,
cache_max_entry_count=48,
cache_chunk_size=1,
use_context_fmha=1,
use_context_fmha=int(kv_head_num == head_num),
quant_policy=0,
tensor_para_size=tp))
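
The exported config now carries `kv_head_num` alongside `head_num` and derives `use_context_fmha` from them: when the two differ (i.e. the model uses GQA, as Llama-2 70B does), context FMHA is turned off, presumably because the context-phase attention path did not yet support grouped KV heads. A minimal sketch of the derivation, assuming the Llama-2 70B head counts:

```python
# Hypothetical head counts for Llama-2 70B; only the flag derivation is from this PR.
head_num, kv_head_num = 64, 8

cfg = dict(
    head_num=head_num,
    kv_head_num=kv_head_num,
    # GQA (kv_head_num != head_num) disables the context FMHA path.
    use_context_fmha=int(kv_head_num == head_num),
)
print(cfg)  # {'head_num': 64, 'kv_head_num': 8, 'use_context_fmha': 0}
```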

@@ -198,6 +202,15 @@ def save_bin(param: torch.Tensor, name):
return True


def merge_qkv(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, tp: int,
dim: int):

def reshape(x):
return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)

return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)
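
`merge_qkv` replaces the `torch.stack` approach used further down (see the next hunks), which required Q, K and V to have identical shapes; with GQA the K/V projections are narrower than Q, so each tensor is instead reshaped per tensor-parallel rank and the ranks' slices are concatenated along the last dimension. A sketch with hypothetical Llama-2 70B weight shapes (hidden size 8192, 64 query heads, 8 KV heads, head dim 128, `tp=8`):

```python
import torch


def merge_qkv(q, k, v, tp, dim):  # repeated from the diff so the sketch runs standalone

    def reshape(x):
        return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)

    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)


hidden, head_num, kv_head_num, head_dim, tp = 8192, 64, 8, 128, 8
wq = torch.empty(hidden, head_num * head_dim)     # [8192, 8192]
wk = torch.empty(hidden, kv_head_num * head_dim)  # [8192, 1024]
wv = torch.empty(hidden, kv_head_num * head_dim)  # [8192, 1024]

w_qkv = merge_qkv(wq, wk, wv, tp, dim=2)
# [8192, 8, 1280]: each TP rank holds 8 query heads plus 1 K head and 1 V head,
# which is why export() now splits 'w_qkv' along dim -2 instead of -1.
print(w_qkv.shape)
```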


def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
triton_models_path: str, tp: int):
"""Deploy a model with huggingface transformers' format.
@@ -223,6 +236,8 @@ def deploy_llama(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f)
num_layer = model_arg['n_layers']
norm_eps = model_arg['norm_eps']
head_num = model_arg.get('n_heads', 32)
kv_head_num = model_arg.get('n_kv_heads', head_num)
except Exception as e:
print(f'get "n_layers" and "norm_eps" from {params_path} failed: {e}')
return False
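
For meta-format checkpoints the KV head count comes from `n_kv_heads` in `params.json`, falling back to `n_heads` for the non-GQA 7B/13B variants. A small sketch, assuming (abridged) Llama-2 70B values:

```python
import json

# Hypothetical, abridged params.json content for Llama-2 70B.
model_arg = json.loads(
    '{"dim": 8192, "n_layers": 80, "n_heads": 64, "n_kv_heads": 8, "norm_eps": 1e-05}')

head_num = model_arg.get('n_heads', 32)
kv_head_num = model_arg.get('n_kv_heads', head_num)  # 7B/13B ship no 'n_kv_heads' key
print(head_num, kv_head_num)  # 64 8
```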
@@ -268,7 +283,6 @@ def get_param(_name, _size):
else: # bias
param = get_param(param_name, [size])
param.data = param_data

elif i == 0:
param = get_param(param_name, param_data.size())
param.data = param_data
@@ -291,14 +305,14 @@ def get_param(_name, _size):
qkv = tuple(map(model_params.pop, _qkv))
except KeyError:
break
# concat by output_dims
qkv = torch.stack(qkv, dim=qkv[0].dim() - 1)
# concat by heads
qkv = merge_qkv(*qkv, tp, dim=2 if t == 'weight' else 1)
print(f'layers.{i}.attention.w_qkv.{t}', qkv.shape)
model_params[f'layers.{i}.attention.w_qkv.{t}'] = qkv

assert i == 0 or num_layer == i, f'miss matched layers: {num_layer} vs {i}'

return export(model_name, num_layer, norm_eps, model_params,
return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
tokenizer_path, triton_models_path, tp)


@@ -349,6 +363,10 @@ def deploy_hf(model_name: str, model_path: str, tokenizer_path: str,
model_arg = json.load(f)
num_layer = model_arg['num_hidden_layers']
norm_eps = model_arg['rms_norm_eps']
if 'num_key_value_heads' in model_arg:
kv_head_num = model_arg['num_key_value_heads']
else:
kv_head_num = model_arg['num_attention_heads']
except Exception as e:
print(f'get "num_hidden_layers" and "rms_norm_eps" from '
f'{params_path} failed: {e}')
@@ -416,11 +434,10 @@ def get_tensor_transposed(name: str):
q = permute(q)
k = permute(k)
if suffix == _qweight: # weight, qweight
# insert a dimension for splitting heads later
qkv = torch.stack((q, k, v), dim=1)
qkv = merge_qkv(q, k, v, tp, dim=2)
print(suffix, qkv.shape)
else: # scales, zeros, bias
qkv = torch.stack((q.squeeze(), k.squeeze(), v.squeeze()),
dim=0).squeeze(dim=-1)
qkv = merge_qkv(q, k, v, tp, dim=1)
print(suffix, qkv.shape)
for k, v in [('w_qkv', qkv), ('wo', o)]:
model_params[f'layers.{i}.attention.{k}.{suffix}'] = v
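
The same helper handles the HuggingFace path, with `dim=2` for the 2-D (q)weight matrices and `dim=1` for per-output-channel tensors such as the bias (and the scales/zeros of quantized checkpoints). A sketch of the `dim=1` case with hypothetical bias shapes for the 70B head layout:

```python
import torch


def merge_qkv(q, k, v, tp, dim):  # repeated from the diff so the sketch runs standalone

    def reshape(x):
        return x.view(x.size(0), tp, -1) if dim == 2 else x.view(tp, -1)

    return torch.cat((reshape(q), reshape(k), reshape(v)), dim=-1)


head_num, kv_head_num, head_dim, tp = 64, 8, 128, 8
bq = torch.empty(head_num * head_dim)     # [8192]
bk = torch.empty(kv_head_num * head_dim)  # [1024]
bv = torch.empty(kv_head_num * head_dim)  # [1024]

b_qkv = merge_qkv(bq, bk, bv, tp, dim=1)
print(b_qkv.shape)  # torch.Size([8, 1280]), matching the per-rank slices of w_qkv
```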
Expand Down Expand Up @@ -456,7 +473,7 @@ def get_tensor_transposed(name: str):
for ft, hf in other:
model_params[ft] = get_tensor(hf)

return export(model_name, num_layer, norm_eps, model_params,
return export(model_name, num_layer, norm_eps, kv_head_num, model_params,
tokenizer_path, triton_models_path, tp)


src/turbomind/kernels/decoder_masked_multihead_attention.h: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ struct Multihead_attention_params: public Multihead_attention_params_base<T> {
T** v_cache_per_sample = nullptr;
size_t kv_cache_per_sample_offset = 0;
bool k_cache_interleaved = true;
int num_kv_heads = 0;
};

template<class T>
@@ -677,7 +677,7 @@ inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
" {%7, %7, %7, %7}; \n"

: "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
: "r"(a.x) "r"(a.y), "r"(b), "f"(zero));
: "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
return c;
}

@@ -1349,16 +1349,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (params.finished != nullptr && params.finished[bi] == true) {
return;
}
// The beam idx
const int beami = bi % params.beam_width;
// The "beam-aware" batch idx
const int bbi = bi / params.beam_width;

// The head.
const int hi = blockIdx.x;
// Combine the batch and the head indices.
const int bhi = bi * params.num_heads + hi;
// Combine the "beam-aware" batch idx and the head indices.
const int bbhi = bbi * params.beam_width * params.num_heads + hi;

const int head_n_rep = params.num_heads / params.num_kv_heads;

const int kvhi = hi / head_n_rep; // heads in the same group collapse to the same kv head

const bool group_leader = hi % head_n_rep == 0; // only group leader writes to kv cache

// The thread in the block.
const int tidx = threadIdx.x;
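
The decode kernel maps each query head `hi` onto a KV head `kvhi` by integer division, and only one head per group (the `group_leader`) writes the shared K/V cache entry, so the cache holds `num_kv_heads` heads per token instead of `num_heads`. A small sketch of the mapping, assuming 64 query heads and 8 KV heads:

```python
# Hypothetical GQA layout: 64 query heads grouped onto 8 KV heads (head_n_rep = 8).
num_heads, num_kv_heads = 64, 8
head_n_rep = num_heads // num_kv_heads

for hi in (0, 1, 7, 8, 15, 63):
    kvhi = hi // head_n_rep              # heads in the same group share one KV head
    group_leader = hi % head_n_rep == 0  # only the group leader stores to the KV cache
    print(f'hi={hi:2d} -> kvhi={kvhi}, group_leader={group_leader}')
```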

@@ -1369,8 +1371,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>

float qk = 0.0F;

int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;

const size_t bi_seq_len_offset = bi * params.memory_max_len;

const int tlength = params.length_per_sample[bi] + params.max_prefix_prompt_length;
@@ -1380,10 +1380,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const bool is_masked = tidx >= QK_VECS_PER_WARP;

const int q_base_offset = bi * params.stride + hi * Dh;
const int k_base_offset = bi * params.stride + kvhi * Dh;

// The offset in the Q and K buffer also accounts for the batch.
int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
const int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
const int k_offset = k_base_offset + tidx * QK_VEC_SIZE;

// The offset in the bias buffer.
int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
const int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
const int k_bias_offset = kvhi * Dh + tidx * QK_VEC_SIZE;

// past kv quant param
const float k_scale = params.attention_k_scale;
@@ -1393,31 +1399,30 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
Qk_vec_k q;
zero(q);
if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[qk_offset]));
q = vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q[q_offset]));
}

Qk_vec_k k;
zero(k);
{
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[qk_offset])) :
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k[k_offset])) :
k;
}

// Trigger the loads from the Q and K bias buffers.
Qk_vec_k q_bias;
zero(q_bias);
q_bias =
(!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[qk_bias_offset])) :
q_bias;
q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.q_bias[q_bias_offset])) :
q_bias;

Qk_vec_k k_bias;
zero(k_bias);
if (handle_kv) {
k_bias =
!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[qk_bias_offset])) :
vec_conversion<Qk_vec_k, Qk_vec_m>(*reinterpret_cast<const Qk_vec_m*>(&params.k_bias[k_bias_offset])) :
k_bias;
}

@@ -1454,7 +1459,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The position of the thread in that 16B chunk.
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

if (handle_kv) {
if (handle_kv && group_leader) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
if (!params.k_cache_per_sample) {
@@ -1476,12 +1481,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else {
int offset;
if (params.k_cache_interleaved) {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh
offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ co * params.memory_max_len * QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci;
}
else {
offset = params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci;
offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ tlength_circ * Dh + co * QK_ELTS_IN_16B + ci;
}

if (not QUANT_POLICY) {
@@ -1577,7 +1582,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>

if (not QUANT_POLICY) {
k_cache_batch = params.k_cache_per_sample ? (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + ki) :
+ kvhi * params.memory_max_len * Dh + ki) :
&params.k_cache[bhi * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
@@ -1586,7 +1591,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// convert k_cache_per_sample to int8
if (params.k_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + ki;
k_cache_batch_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + ki;
}
else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache);
@@ -1765,7 +1770,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
if (not QUANT_POLICY) {

v_cache = params.v_cache_per_sample ? (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+ hi * params.memory_max_len * Dh + vi) :
+ kvhi * params.memory_max_len * Dh + vi) :
&params.v_cache[bhi * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer
// T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
@@ -1774,7 +1779,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
else if (QUANT_POLICY == 4) {
if (params.v_cache_per_sample) {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
v_cache_int8 = ptr + params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + vi;
}
else {
int8_t* ptr = reinterpret_cast<int8_t*>(params.v_cache);
@@ -1787,22 +1792,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// The number of values processed per iteration of the loop.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

// One group of threads computes the product(s) for the current timestep.
V_vec_k v_bias;
zero(v_bias);
// if( vo == params.timestep % V_PER_ITER ) {
if (Dh == Dh_MAX || vi < Dh) {
if (handle_kv) {
if (vo == tlength % V_PER_ITER) {
// Trigger the loads from the V bias buffer.
if (params.v_bias != nullptr) {
v_bias = vec_conversion<V_vec_k, V_vec_m>(
*reinterpret_cast<const V_vec_m*>(&params.v_bias[hi * Dh + vi]));
}
}
}
}

// From previous, before values, step
// Also make sure the logits are in shared memory.
__syncthreads();
@@ -1924,14 +1913,18 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
V_vec_k v;

// Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi;
v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));
const auto v_offset = k_base_offset + vi;

v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&params.v[v_offset]));

// Trigger the loads from the V bias buffer.
// V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
if (params.v_bias != nullptr) {
V_vec_k v_bias = *reinterpret_cast<const V_vec_k*>(&params.v_bias[kvhi * Dh + vi]);
v = add(v, v_bias);
}

// Compute the V values with bias.
if (handle_kv) {
v = add(v, v_bias);
// Store the V values to cache
if (handle_kv && group_leader) {

// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
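
Taken together, the kernel changes make each query head attend over the cached K/V of its group's KV head, while the freshly projected K/V (with bias applied) is written to the cache once per group. A minimal PyTorch reference of one decode step under those assumptions; it ignores the kernel's tiling, quantization and beam-search handling:

```python
import torch

# Hypothetical toy sizes: 4 query heads sharing 2 KV heads, head dim 8, 5 cached tokens.
num_heads, num_kv_heads, Dh, t = 4, 2, 8, 5
head_n_rep = num_heads // num_kv_heads

q = torch.randn(num_heads, Dh)          # queries for the current timestep
k_new = torch.randn(num_kv_heads, Dh)   # new K for this timestep (bias already added)
v_new = torch.randn(num_kv_heads, Dh)   # new V for this timestep (bias already added)
k_cache = torch.randn(num_kv_heads, t, Dh)
v_cache = torch.randn(num_kv_heads, t, Dh)

# Append the new K/V once per KV head, as the kernel's group leaders do.
k_cache = torch.cat([k_cache, k_new[:, None]], dim=1)  # [num_kv_heads, t + 1, Dh]
v_cache = torch.cat([v_cache, v_new[:, None]], dim=1)

out = torch.empty(num_heads, Dh)
for hi in range(num_heads):
    kvhi = hi // head_n_rep                       # same mapping as the kernel's kvhi
    logits = (k_cache[kvhi] @ q[hi]) / Dh ** 0.5  # one score per cached timestep
    probs = torch.softmax(logits, dim=0)
    out[hi] = probs @ v_cache[kvhi]

print(out.shape)  # torch.Size([4, 8])
```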