32 changes: 29 additions & 3 deletions paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -55,7 +55,7 @@
)
from paddlenlp.utils.log import logger

__all__ = ["DeepseekV2ForCausalLMBlockInferenceModel"]
__all__ = ["DeepseekV2ForCausalLMBlockInferenceModel", "DeepseekVLV2ForCausalLMBlockInferenceModel"]


class DeepseekScalingRotaryEmbedding(nn.Layer):
@@ -1316,8 +1316,14 @@
kwargs["max_input_length"] = self.max_seq_len
kwargs["block_size"] = self.block_size

inputs_embeds = self.embed_tokens(ids_remove_padding)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(ids_remove_padding)

+        else:
+            assert len(inputs_embeds.shape) == 3
+            # This is the case in the image-to-text model.
+            # In the prefill phase, the language model is first fed with inputs_embeds instead of input_ids,
+            # but in the decode phase it is fed with input_ids just like a normal text-to-text model.
+            inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])
        with dy2st_nocheck_guard_context():
            hidden_states, _ = self.transformer_block(
                input_ids=input_ids,
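
Note: the branch added above separates the multimodal prefill path (precomputed embeddings of shape [batch, seq_len, hidden] flattened to the remove-padding layout [num_tokens, hidden]) from the decode path, where embeddings are looked up from ids_remove_padding. A minimal standalone sketch of that shape handling; the sizes and the resolve_embeddings helper are illustrative only and not part of this PR:

import paddle
import paddle.nn as nn

hidden_size, vocab_size = 64, 1000
embed_tokens = nn.Embedding(vocab_size, hidden_size)

def resolve_embeddings(ids_remove_padding, inputs_embeds=None):
    if inputs_embeds is None:
        # text-to-text prefill or any decode step: look the tokens up
        return embed_tokens(ids_remove_padding)
    # image-to-text prefill: embeddings arrive as [batch, seq_len, hidden]
    # and are flattened to the remove-padding layout [num_tokens, hidden]
    assert len(inputs_embeds.shape) == 3
    return inputs_embeds.reshape([-1, inputs_embeds.shape[2]])

ids = paddle.to_tensor([3, 7, 11], dtype="int64")
print(resolve_embeddings(ids).shape)                        # [3, 64]
fused = paddle.randn([2, 5, hidden_size])                   # e.g. vision + text embeddings
print(resolve_embeddings(None, inputs_embeds=fused).shape)  # [10, 64]
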
Expand Down Expand Up @@ -1573,6 +1579,7 @@
    def prepare_inputs_for_generation(self, **kwargs):
        # only last token for inputs_ids if cache is defined in kwargs
        input_ids = kwargs["input_ids"]
+        inputs_embeds = kwargs.get("inputs_embeds", None)
        src_mask = kwargs.get("src_mask", None)
        block_tables = kwargs.get("block_tables", None)

@@ -1593,6 +1600,7 @@

        model_inputs = {
            "input_ids": input_ids,
+            "inputs_embeds": inputs_embeds,
            "src_mask": src_mask,
            "rope_emb": None,
            "pre_caches": pre_caches,
@@ -1613,6 +1621,7 @@
    def forward(
        self,
        input_ids,
+        inputs_embeds=None,
        src_mask=None,
        pre_caches=None,
        caches=None,
@@ -1630,6 +1639,7 @@
    ):
        outputs = self.deepseek_v2(
            input_ids,
+            inputs_embeds=inputs_embeds,
            src_mask=src_mask,
            caches=caches,
            rope_emb=None,
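
Note: the three hunks above simply thread the new inputs_embeds argument from the generation kwargs through prepare_inputs_for_generation and forward into self.deepseek_v2. A stripped-down sketch of that pass-through pattern; TinyGenerator is a stand-in for illustration, not the PaddleNLP implementation:

class TinyGenerator:
    def prepare_inputs_for_generation(self, **kwargs):
        return {
            "input_ids": kwargs["input_ids"],
            # None for plain text generation and for every decode step;
            # a [batch, seq_len, hidden] tensor during multimodal prefill.
            "inputs_embeds": kwargs.get("inputs_embeds", None),
        }

    def forward(self, input_ids, inputs_embeds=None):
        source = "token ids" if inputs_embeds is None else "precomputed embeddings"
        return f"forward pass driven by {source}"

gen = TinyGenerator()
print(gen.forward(**gen.prepare_inputs_for_generation(input_ids=[1, 2, 3])))
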
@@ -1805,3 +1815,19 @@
        )

        return logits, hidden_states


+class DeepseekVLV2ForCausalLMBlockInferenceModel(DeepseekV2ForCausalLMBlockInferenceModel):
+    def __init__(self, config: DeepseekV2Config):
+        super().__init__(config, base_model_prefix="language.model")

+    def get_input_embeddings(self):
+        return self.deepseek_v2.embed_tokens

+    @paddle.no_grad()
+    def set_state_dict(self, state_dict):
+        if "language.lm_head.weight" in state_dict:
+            self.lm_head.weight.set_value(
+                paddle.to_tensor(state_dict["language.lm_head.weight"]).cast(self.lm_head.weight.dtype)
+            )
+        self.deepseek_v2.set_state_dict({k: state_dict[k] for k in state_dict.keys()})
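
Note: the new subclass wires the language tower of a DeepSeek-VL2 checkpoint into the existing block-inference model: weights live under a "language." prefix (hence base_model_prefix="language.model"), and set_state_dict routes "language.lm_head.weight" to lm_head while the full dict is handed to the DeepseekV2 backbone. A small sketch of that key-routing idea; all names and shapes below are illustrative only:

import paddle

def route_vl_checkpoint(state_dict, head_key="language.lm_head.weight"):
    # Pick out the lm_head weight (if present) and forward the full dict
    # to the language backbone, mirroring set_state_dict above.
    head_weight = state_dict.get(head_key)
    backbone_dict = {k: state_dict[k] for k in state_dict.keys()}
    return head_weight, backbone_dict

ckpt = {
    "language.lm_head.weight": paddle.randn([8, 4]),
    "language.model.embed_tokens.weight": paddle.randn([8, 4]),
}
head, backbone = route_vl_checkpoint(ckpt)
print(head.shape, sorted(backbone.keys()))
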
