Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix LLAMA Inference Bug #6865

Merged
merged 1 commit into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llm/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ def __init__(

self.cache_kvs = [
paddle.zeros(shape, dtype=dtype)
for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size)
for shape in self.model.get_cache_kvs_shape(self.model.config, config.max_batch_size, config.max_length)
]
self.pre_ids = paddle.full([config.max_batch_size, config.max_length], -1, dtype="int64")
if "chatglm" in self.architectures:
Expand Down
25 changes: 13 additions & 12 deletions paddlenlp/experimental/transformers/chatglm/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,18 +294,19 @@
new_cache = [None]
hidden_states = self.input_layernorm(hidden_states)

hidden_states, new_cache = self.transformer_block(
input_ids,
hidden_states,
cum_offsets=cum_offsets,
padding_offset=padding_offset,
attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
caches=cache_kvs,
rotary_embs=paddle.cast(rotary_embeds, "float32"),
rotary_emb_dims=2 if self.config.position_encoding_2d else 1,
seq_lens=seq_lens,
time_step=time_step,
)
with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
hidden_states, new_cache = self.transformer_block(

Check warning on line 298 in paddlenlp/experimental/transformers/chatglm/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/chatglm/modeling.py#L297-L298

Added lines #L297 - L298 were not covered by tests
input_ids,
hidden_states,
cum_offsets=cum_offsets,
padding_offset=padding_offset,
attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
caches=cache_kvs,
rotary_embs=paddle.cast(rotary_embeds, "float32"),
rotary_emb_dims=2 if self.config.position_encoding_2d else 1,
seq_lens=seq_lens,
time_step=time_step,
)
return (hidden_states, new_cache)

@paddle.no_grad()
Expand Down
10 changes: 5 additions & 5 deletions paddlenlp/experimental/transformers/generation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@
if cache is None:
# encoder's generation
model_kwargs["tgt_ids"] = paddle.where(just_decoder, model_kwargs["tgt_ids"], next_tokens)
if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
if self.config["position_encoding_2d"] and self.config.position_encoding_2d is True:

Check warning on line 165 in paddlenlp/experimental/transformers/generation_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/generation_utils.py#L165

Added line #L165 was not covered by tests
tgt_pos = model_kwargs["tgt_pos"]
new_position_id = tgt_pos[:, 0, :].clone()
new_block_id = tgt_pos[:, 1, :].clone()
Expand All @@ -182,7 +182,7 @@
)
else:
model_kwargs["tgt_ids"] = next_tokens
if self.model.config["position_encoding_2d"] and self.model.config.position_encoding_2d is True:
if self.config["position_encoding_2d"] and self.config.position_encoding_2d is True:

Check warning on line 185 in paddlenlp/experimental/transformers/generation_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/generation_utils.py#L185

Added line #L185 was not covered by tests
tgt_pos = model_kwargs["tgt_pos"]
new_position_id = tgt_pos[:, 0, :].clone()
new_block_id = tgt_pos[:, 1, :].clone()
Expand Down Expand Up @@ -261,9 +261,9 @@
# compute next_tokens, use paddle.top_p_sampling
logits = logits / temperature

_, next_tokens = top_p_sampling(probs, top_p)
_, next_tokens = top_p_sampling(probs, top_p, -1)

Check warning on line 264 in paddlenlp/experimental/transformers/generation_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/generation_utils.py#L264

Added line #L264 was not covered by tests

if self.model.config.tensor_parallel_degree > 1:
if self.config.tensor_parallel_degree > 1:

Check warning on line 266 in paddlenlp/experimental/transformers/generation_utils.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/generation_utils.py#L266

Added line #L266 was not covered by tests
paddle.distributed.broadcast(next_tokens, 0)

model_kwargs = self.update_model_kwargs_for_generation(
Expand All @@ -275,7 +275,7 @@
batch_idx,
step_idx_ori,
"real_time_save.temp_ids",
self.model.config.tensor_parallel_rank,
self.config.tensor_parallel_rank,
)

return next_tokens, model_kwargs
Expand Down
35 changes: 18 additions & 17 deletions paddlenlp/experimental/transformers/llama/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from paddlenlp.experimental.transformers.generation_utils import (
GenerationInferenceModel,
)
from paddlenlp.transformers import LlamaConfig, LlamaForCausalLM, LlamaPretrainedModel
from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel

Check warning on line 28 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L28

Added line #L28 was not covered by tests
from paddlenlp.transformers.llama.modeling import LlamaLMHead
from paddlenlp.transformers.model_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
Expand Down Expand Up @@ -200,18 +200,19 @@

new_rope = fused_get_rotary_embedding(input_ids, position_ids, self.head_dim_shape_tensor, 0, True)

hidden_states, _ = self.transformer_block(
input_ids,
hidden_states,
cum_offsets=cum_offsets,
padding_offset=padding_offset,
attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
caches=cache_kvs,
seq_lens=seq_lens,
rotary_embs=new_rope,
rotary_emb_dims=1,
time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
)
with paddle.fluid.framework._stride_in_no_check_dy2st_diff():
hidden_states, _ = self.transformer_block(

Check warning on line 204 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L203-L204

Added lines #L203 - L204 were not covered by tests
input_ids,
hidden_states,
cum_offsets=cum_offsets,
padding_offset=padding_offset,
attn_mask=paddle.cast(attention_mask, dtype=hidden_states.dtype),
caches=cache_kvs,
seq_lens=seq_lens,
rotary_embs=new_rope,
rotary_emb_dims=1,
time_step=paddle.increment(paddle.shape(attention_mask)[-1], -1) if is_decoder else None,
)
hidden_states = self.norm(hidden_states)

if output_hidden_states:
Expand Down Expand Up @@ -289,7 +290,7 @@
)


class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaForCausalLM):
class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaPretrainedModel):

Check warning on line 293 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L293

Added line #L293 was not covered by tests
"""
Dynamic Batching for LLaMA Model with pretraining tasks on top.
"""
Expand All @@ -298,7 +299,7 @@

def __init__(self, config):
super().__init__(config)
self.model = LlamaInferenceModel(config)
self.llama = LlamaInferenceModel(config)

Check warning on line 302 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L302

Added line #L302 was not covered by tests
self.lm_head = LlamaLMHead(config)

@classmethod
Expand Down Expand Up @@ -384,7 +385,7 @@
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.model(
outputs = self.llama(

Check warning on line 388 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L388

Added line #L388 was not covered by tests
input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
Expand Down Expand Up @@ -430,4 +431,4 @@
def set_state_dict(self, state_dict):
if "lm_head.weight" in state_dict:
self.lm_head.weight.set_value(state_dict["lm_head.weight"])
self.model.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()})

Check warning on line 434 in paddlenlp/experimental/transformers/llama/modeling.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/experimental/transformers/llama/modeling.py#L434

Added line #L434 was not covered by tests