Add bf16 inference for llm model (#387)
Co-authored-by: baishihao <baishihao@sensetime.com>
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
3 people committed Apr 10, 2024
1 parent cf15435 commit 15a050a
Showing 26 changed files with 83 additions and 64 deletions.
28 changes: 20 additions & 8 deletions lightllm/common/basemodel/basemodel.py
@@ -45,7 +45,9 @@ def __init__(self, kvargs):
self.max_seq_length = kvargs.get("max_seq_length", 1024 * 5)
self.return_all_prompt_logprobs = kvargs.get("return_all_prompt_logprobs", False)
self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)

+ self.data_type = kvargs.get("data_type", "float16")

+ self._init_datatype()
self._init_config()
self._verify_must()
self._verify_params()
@@ -80,16 +82,16 @@ def _verify_params(self):

def _init_weights(self):
self.pre_post_weight = self.pre_and_post_weight_class(
- self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode
+ self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
self.trans_layers_weight = [
self.transformer_weight_class(
- i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode
+ i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
for i in range(self.config["n_layer"])
]
load_hf_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
@@ -103,7 +105,7 @@ def _init_mem_manager(self):
assert self.config["num_attention_heads"] % self.world_size_ == 0
self.mem_manager = MemoryManager(
self.max_total_token_num,
- dtype=torch.float16,
+ dtype=self.data_type,
head_num=self.config["num_attention_heads"] // self.world_size_,
head_dim=self.config["n_embed"] // self.config["num_attention_heads"],
layer_num=self.config["n_layer"],
@@ -137,6 +139,16 @@ def _init_some_value(self):
self.vocab_size = self.config["vocab_size"]
return

+ def _init_datatype(self):
+     if self.data_type in ["fp16", "float16"]:
+         self.data_type = torch.float16
+     elif self.data_type in ["bf16", "bfloat16"]:
+         self.data_type = torch.bfloat16
+     elif self.data_type in ["fp32", "float32"]:
+         self.data_type = torch.float32
+     else:
+         raise ValueError(f"Unsupported datatype {self.data_type}!")

def _init_custom(self):
pass

@@ -223,7 +235,7 @@ def _prefill(
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(input_ids.shape[0], self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
- dtype=torch.float16,
+ dtype=self.data_type,
device="cuda",
)

@@ -279,7 +291,7 @@ def _decode(
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
- dtype=torch.float16,
+ dtype=self.data_type,
device="cuda",
)
copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
@@ -341,7 +353,7 @@ def splitfuse_forward(
infer_state.mem_index = alloc_mem
infer_state.kv_buffer = torch.empty(
(alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
- dtype=torch.float16,
+ dtype=self.data_type,
device="cuda",
)

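The basemodel.py changes above follow one pattern: read a "data_type" string from kvargs once, resolve it to a torch.dtype in _init_datatype(), and reuse self.data_type everywhere a tensor was previously hard-coded to torch.float16 (weights, MemoryManager, kv_buffer). A minimal sketch of that resolution step, using a hypothetical helper name (resolve_dtype is not part of lightllm) and illustrative shapes:

import torch

_DTYPE_ALIASES = {
    "fp16": torch.float16, "float16": torch.float16,
    "bf16": torch.bfloat16, "bfloat16": torch.bfloat16,
    "fp32": torch.float32, "float32": torch.float32,
}

def resolve_dtype(data_type):
    # Accept either a user-facing string such as "bf16" or an existing torch.dtype.
    if isinstance(data_type, torch.dtype):
        return data_type
    try:
        return _DTYPE_ALIASES[data_type]
    except KeyError:
        raise ValueError(f"Unsupported datatype {data_type}!") from None

# Example: a kv buffer allocated in the resolved dtype instead of a hard-coded float16.
dtype = resolve_dtype("bf16")
kv_buffer = torch.empty((4, 8, 128), dtype=dtype)  # shape is illustrative only

Resolving the string once keeps the rest of the model dtype-agnostic; every later allocation simply reuses the stored dtype.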
3 changes: 2 additions & 1 deletion lightllm/common/basemodel/layer_weights/hf_load_utils.py
@@ -27,7 +27,8 @@ def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_lay


def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
- data_type = torch.float16 if data_type == 'fp16' else torch.float32
+ if isinstance(data_type, str):
+     data_type = torch.float16 if data_type == 'fp16' else torch.float32
if pre_post_layer is not None:
assert pre_post_layer.data_type_ == data_type, "type is not right"
if transformer_layer_list is not None:
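load_hf_weights previously assumed data_type was a string. The isinstance guard keeps the legacy string behaviour ('fp16' selects float16, anything else falls back to float32) while letting callers pass a torch.dtype straight through, which is what basemodel.py now does. A standalone sketch of that normalization, with a hypothetical function name:

import torch

def normalize_weight_dtype(data_type):
    # Legacy string behaviour is preserved: only "fp16" selects float16,
    # every other string falls back to float32. torch.dtype objects
    # (e.g. torch.bfloat16 coming from the model's data_type) pass through untouched.
    if isinstance(data_type, str):
        return torch.float16 if data_type == "fp16" else torch.float32
    return data_type

assert normalize_weight_dtype("fp16") is torch.float16
assert normalize_weight_dtype(torch.bfloat16) is torch.bfloat16

Note that a bf16 run therefore has to hand the loader a torch.dtype rather than the string "bf16", since the string branch would silently pick float32.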
@@ -21,8 +21,8 @@ def load_hf_weights(self, weights):
(self.tp_rank_ + 1), :])
if 'lm_head.weight' in weights:
# print(weights['lm_head.weight'].shape)
- self.lm_head_weight_ = nn.functional.normalize(weights['lm_head.weight'].to(
-     torch.float16).cuda())[split_vob_size * self.tp_rank_:split_vob_size * (self.tp_rank_ + 1), :]
+ self.lm_head_weight_ = self._cuda(
+     nn.functional.normalize(weights['lm_head.weight'])[split_vob_size * self.tp_rank_:split_vob_size * (self.tp_rank_ + 1), :])
if 'model.norm.weight' in weights:
self.final_norm_weight_ = self._cuda(weights['model.norm.weight'])

4 changes: 2 additions & 2 deletions lightllm/models/bloom/layer_infer/post_layer_infer.py
@@ -29,7 +29,7 @@ def soft_max(self, data):

def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: BloomPreAndPostLayerWeight, return_logics=False):
batch_size = infer_state.batch_size
- last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
+ last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
if infer_state.is_prefill:
last_index = torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
last_input[:, :] = input_embdings[last_index, :]
@@ -44,7 +44,7 @@ def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weigh
if self.world_size_ == 1:
gather_data = logic_batch
else:
- gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=torch.float16)
+ gather_data = torch.empty((self.vocab_size_, batch_size), device=logic_batch.device, dtype=input_embdings.dtype)
split_size = self.vocab_size_ // self.world_size_
dist.all_gather([gather_data[i * split_size: (i + 1) * split_size, :]
for i in range(self.world_size_)], logic_batch, group=None, async_op=False)
3 changes: 2 additions & 1 deletion lightllm/models/bloom/layer_weights/hf_load_utils.py
@@ -5,7 +5,8 @@


def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
- data_type = torch.float16 if data_type == 'fp16' else torch.float32
+ if isinstance(data_type, str):
+     data_type = torch.float16 if data_type == 'fp16' else torch.float32
if pre_post_layer is not None:
assert pre_post_layer.data_type_ == data_type, "type is not right"
if transformer_layer_list is not None:
6 changes: 3 additions & 3 deletions lightllm/models/bloom/model.py
@@ -40,13 +40,13 @@ def _reset_num_key_value_heads(self):
return

def _init_weights(self):
- self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
+ self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
self.trans_layers_weight = [
- self.transformer_weight_class(i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
+ self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
for i in range(self.config["n_layer"])
]
load_hf_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
@@ -37,7 +37,7 @@ def _fwd_kernel_token_att2(
v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)
acc += tl.sum(p_value[:, None] * v_value, 0)

- acc = acc.to(tl.float16)
+ acc = acc.to(Out.dtype.element_ty)
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)
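The Triton kernels in this commit stop casting results to a hard-coded tl.float16 and instead cast to the element type of the pointer they store to (Out.dtype.element_ty, input_ptr.dtype.element_ty, and so on), so a single kernel serves fp16, bf16, and fp32 tensors. A minimal, self-contained illustration of the pattern; the kernel below is an example, not code from lightllm:

import torch
import triton
import triton.language as tl

@triton.jit
def _scale_kernel(X, Out, scale, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(X + offs, mask=mask, other=0.0).to(tl.float32)  # accumulate in fp32
    y = x * scale
    # Cast to whatever the output tensor holds: fp16, bf16, or fp32.
    tl.store(Out + offs, y.to(Out.dtype.element_ty), mask=mask)

def scale(x, s):
    # Works unchanged for float16, bfloat16, and float32 CUDA tensors.
    out = torch.empty_like(x)
    grid = (triton.cdiv(x.numel(), 1024),)
    _scale_kernel[grid](x, out, s, x.numel(), BLOCK=1024)
    return out

The accumulation still happens in float32; only the final store adopts the output tensor's dtype.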
4 changes: 2 additions & 2 deletions lightllm/models/chatglm2/model.py
@@ -70,6 +70,6 @@ def _init_to_get_rotary(self, base=10000):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

- self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
- self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
6 changes: 3 additions & 3 deletions lightllm/models/gemma_2b/model.py
@@ -45,7 +45,7 @@ def _init_custom(self):

def _init_mem_manager(self):
self.mem_manager = MemoryManager(self.max_total_token_num,
- dtype=torch.float16,
+ dtype=self.data_type,
head_num=self.config["num_key_value_heads"], # [SYM] always == 1
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"])
@@ -73,7 +73,7 @@ def _init_to_get_rotary(self, default_base=10000):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

- self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
- self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

2 changes: 1 addition & 1 deletion lightllm/models/gemma_2b/triton_kernel/gelu_and_mul.py
@@ -56,7 +56,7 @@ def _gelu_and_mul_kernel(
).to(tl.float32)

gate = gelu(gate)
- gate = gate.to(tl.float16)
+ gate = gate.to(input_ptr.dtype.element_ty)

tl.store(
input_ptr + res_offsets,
6 changes: 3 additions & 3 deletions lightllm/models/llama/layer_infer/post_layer_infer.py
@@ -30,7 +30,7 @@ def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo
if infer_state.is_splitfuse:
# for SplitFuse
batch_size = infer_state.batch_size
- last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
+ last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
tmp_ = torch.cat(
[
torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"),
@@ -44,7 +44,7 @@ def token_forward(

if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logprobs:
batch_size = infer_state.batch_size
- last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=torch.float16)
+ last_input = torch.empty((batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype)
last_index = (
torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1
)
@@ -81,7 +81,7 @@ def token_forward(
if self.world_size_ == 1:
gather_data = logic_batch
else:
- gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=torch.float16)
+ gather_data = torch.empty((self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings.dtype)
split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64)
dist.all_gather(
[gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)],
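In the post layers, temporary buffers (last_input, gather_data) are now allocated with input_embdings.dtype instead of torch.float16. Deriving the dtype from the incoming activations keeps the layer correct for whichever data_type the model was built with, without threading the model config through. A tiny sketch of the idea, with hypothetical names:

import torch

def alloc_last_input(input_embdings, batch_size, embed_dim):
    # Match the activation dtype so fp16, bf16, and fp32 runs all work unchanged.
    return torch.empty((batch_size, embed_dim),
                       device=input_embdings.device,
                       dtype=input_embdings.dtype)

embs = torch.randn(10, 4096, dtype=torch.bfloat16)  # stand-in for prefill activations
last_input = alloc_last_input(embs, 2, 4096)
assert last_input.dtype is torch.bfloat16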
30 changes: 15 additions & 15 deletions lightllm/models/llama/model.py
@@ -55,7 +55,7 @@ def _verify_params(self):

def _init_mem_manager(self):
self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num,
- dtype=torch.float16,
+ dtype=self.data_type,
head_num=self.config["num_key_value_heads"] // self.world_size_,
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"])
@@ -74,21 +74,21 @@ def _init_custom(self):
return

def _init_weights(self):
- self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
+ self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
self.trans_layers_weight = [
- self.transformer_weight_class(i, self.tp_rank_, self.world_size_, torch.float16, network_config=self.config, mode=self.mode)
+ self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
for i in range(self.config["n_layer"])
]
if self.load_way == 'HF':
load_hf_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict)
else:
load_ds_weights(
"fp16",
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
@@ -132,8 +132,8 @@ def _init_to_get_rotary(self, default_base=10000):
t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

- self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
- self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_dynamic_ntk_rotary(self):
@@ -145,22 +145,22 @@ def _init_to_get_dynamic_ntk_rotary(self):
else:
scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
max_seq_len = max(self.max_seq_length, max_position_embeddings)
- self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
- self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=torch.float16, device="cuda")
+ self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
+ self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")

inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
- self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(torch.float16).cuda()
- self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(torch.float16).cuda()
+ self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda()
+ self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda()

for seq_loc_index in range(max_position_embeddings, max_seq_len, 1):
new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) -(scaling_factor - 1)) ** (partial_head_dim / (partial_head_dim - 2))
inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim))
t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
- self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(torch.float16).cuda()
- self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(torch.float16).cuda()
+ self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
+ self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
return

def _init_to_get_yarn_rotary(self):
@@ -194,8 +194,8 @@ def _init_to_get_yarn_rotary(self):
freqs = torch.einsum("i,j->ij", t, inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
- self._cos_cached = emb.cos().to(torch.float16).cuda() * mscale
- self._sin_cached = emb.sin().to(torch.float16).cuda() * mscale
+ self._cos_cached = emb.cos().to(self.data_type).cuda() * mscale
+ self._sin_cached = emb.sin().to(self.data_type).cuda() * mscale

return

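The rotary-embedding caches follow the same rule across chatglm2, gemma_2b, llama, and mistral: frequencies are still computed in float32 on the CPU for accuracy, and only the cached cos/sin tables are cast to the configured dtype. A sketch under that assumption (the helper name is illustrative, not from lightllm):

import torch

def build_rotary_cache(max_seq_len, partial_head_dim, data_type, base=10000.0):
    # Compute in float32, then cast only the cached tables to the model dtype.
    inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, dtype=torch.float32) / partial_head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)
    return torch.cos(freqs).to(data_type), torch.sin(freqs).to(data_type)

cos_cached, sin_cached = build_rotary_cache(2048, 128, torch.bfloat16)
assert cos_cached.dtype is torch.bfloat16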
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/rmsnorm.py
@@ -35,7 +35,7 @@ def _rms_norm_fwd_fused(
x_hat = x * rstd
y = x_hat * w
# Write output
- tl.store(Y + cols, y.to(tl.float16), mask=mask)
+ tl.store(Y + cols, y.to(Y.dtype.element_ty), mask=mask)


def rmsnorm_forward(x, weight, eps):
2 changes: 1 addition & 1 deletion lightllm/models/llama/triton_kernel/silu_and_mul.py
@@ -44,7 +44,7 @@ def _silu_and_mul_kernel(
).to(tl.float32)

gate = gate / (1 + tl.exp(-gate))
- gate = gate.to(tl.float16)
+ gate = gate.to(input_ptr.dtype.element_ty)

tl.store(
input_ptr + res_offsets,
@@ -301,8 +301,8 @@ def _fwd_kernel_int8(
vs_ptrs + kv_loc[:, None] * stride_vsbs, mask=(start_n + offs_n)[:, None] < cur_batch_seq_len, other=0.0
)

- p = p.to(tl.float16)
- acc += tl.dot(p, v.to(tl.float16) * v_scale)
+ p = p.to(V.dtype.element_ty)
+ acc += tl.dot(p, v.to(V.dtype.element_ty) * v_scale)

# update m_i and l_i
l_i = l_i_new
6 changes: 3 additions & 3 deletions lightllm/models/mistral/model.py
@@ -51,7 +51,7 @@ def _init_custom(self):

def _init_mem_manager(self):
self.mem_manager = MemoryManager(self.max_total_token_num, # [SYM] should be sliding window?
- dtype=torch.float16,
+ dtype=self.data_type,
head_num=self.config["num_key_value_heads"] // self.world_size_,
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"],
@@ -79,7 +79,7 @@ def _init_to_get_rotary(self, default_base=10000):
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

- self._cos_cached = torch.cos(freqs).to(torch.float16).cuda()
- self._sin_cached = torch.sin(freqs).to(torch.float16).cuda()
+ self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+ self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return

@@ -41,7 +41,7 @@ def _fwd_kernel_token_att2(
v_value = tl.load(V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None] + cur_batch_start_index) < cur_batch_seq_len, other=0.0) # [1, D] + [64, 1] = [64, D]
acc += tl.sum(p_value[:, None] * v_value, 0) # [64, 1] * [64, D] = [64, D] -> [D]

- acc = acc.to(tl.float16)
+ acc = acc.to(Out.dtype.element_ty)
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)