[Paddle Inference]support miniGPT4's second part dy2st #6905

Merged
merged 11 commits into from
Sep 7, 2023

Conversation

zhoutianzi666
Contributor

@zhoutianzi666 zhoutianzi666 commented Sep 4, 2023

PR types

PR changes

Description

  • Added a new class, LlamaForminiGPT4InferenceModel, to modeling.py.

    • Added the method generate_text_with_image_features to this class.
    • Overrode to_static so that the model can be exported as a static graph.
  • Modified the generate function in paddlenlp/experimental/transformers/generation_utils.py so that it supports the case where input_ids is None and inputs_embeds is not None.

  • To export the static graph of the language model in miniGPT4, run PaddleNLP/llm/export_model.py with the following command:

  • python3.8 export_model.py --model_name_or_path /zhoukangkang/2023-06-06minigpt/whole_part/llama-13b-fp16/ --output_path /zhoukangkang/2023-06-06minigpt/whole_part/miniGPT4-second-part_kaiyuan_fp16 --dtype float16 --inference_model --model_prefix=llama --model_type=llama-img2txt --max_batch_size=2 > out.txt

@paddle-bot

paddle-bot bot commented Sep 4, 2023

Thanks for your contribution!

@codecov

codecov bot commented Sep 4, 2023

Codecov Report

Merging #6905 (f03d084) into develop (e183825) will decrease coverage by 0.04%.
The diff coverage is 0.00%.

@@             Coverage Diff             @@
##           develop    #6905      +/-   ##
===========================================
- Coverage    59.87%   59.84%   -0.04%     
===========================================
  Files          552      552              
  Lines        81452    81499      +47     
===========================================
  Hits         48772    48772              
- Misses       32680    32727      +47     
Files Changed Coverage Δ
...enlp/experimental/transformers/generation_utils.py 0.00% <0.00%> (ø)
...dlenlp/experimental/transformers/llama/modeling.py 0.00% <0.00%> (ø)

Comment on lines 160 to 164
first_embeds = self.llama.embed_tokens(first_input_ids)
second_embeds = self.llama.embed_tokens(second_input_ids)
image_features = paddle.cast(image_features, dtype=first_embeds.dtype)
inputs_embeds = paddle.concat([first_embeds, image_features, second_embeds], axis=1)

Contributor

In principle, generation_utils.py is model-agnostic and should not absorb much model-specific logic, so the llama-related code should not be embedded here. I suggest moving this logic into llama/modeling.py.

Contributor Author

@zhoutianzi666 zhoutianzi666 Sep 5, 2023

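> In principle, generation_utils.py is model-agnostic and should not absorb much model-specific logic, so the llama-related code should not be embedded here. I suggest moving this logic into llama/modeling.py.

Done, thanks. I have added a new class, LlamaForminiGPT4InferenceModel, to modeling.py.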

@@ -159,12 +159,14 @@ def forward(
cache_kvs=None,
seq_len_encoder=None,
seq_len_decoder=None,
# past_key_values is useless,as it is replaced by kwargs["cache"], so confusion.
Contributor

This is here to align with hf, so using the past_key_values parameter is preferred. I suggest deleting this comment.

input_ids,
eos_token_id,
input_ids=None,
inputs_embeds=None,
Contributor

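The new inputs_embeds parameter needs to go after temperature; otherwise most of the code that calls the sample function will fail, so this change is dangerous.

If you search globally for .sample( you will find that every call site passes its values positionally, as args.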
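To illustrate the risk raised in this comment, here is a minimal sketch using hypothetical stand-in functions (`sample_old`, `sample_bad`, and `sample_safe` are illustrative only and are not the real PaddleNLP signatures). It shows how inserting a parameter ahead of existing ones silently shifts positional arguments, while appending it with a default leaves existing call sites unchanged.

```python
def sample_old(input_ids, eos_token_id, temperature=1.0):
    """Hypothetical original signature."""
    return {"input_ids": input_ids, "eos_token_id": eos_token_id, "temperature": temperature}


def sample_bad(input_ids=None, inputs_embeds=None, eos_token_id=None, temperature=1.0):
    """New parameter inserted in the middle: positional callers are shifted."""
    return {
        "input_ids": input_ids,
        "inputs_embeds": inputs_embeds,
        "eos_token_id": eos_token_id,
        "temperature": temperature,
    }


def sample_safe(input_ids, eos_token_id, temperature=1.0, inputs_embeds=None):
    """New parameter appended with a default: positional callers keep working."""
    return {
        "input_ids": input_ids,
        "inputs_embeds": inputs_embeds,
        "eos_token_id": eos_token_id,
        "temperature": temperature,
    }


# An existing call site that passes everything positionally.
args = ([1, 2, 3], 2, 0.7)

old = sample_old(*args)
bad = sample_bad(*args)    # eos_token_id lands in inputs_embeds, temperature in eos_token_id
safe = sample_safe(*args)  # same meaning as the original call
```

In the `sample_bad` case no exception is raised at the call itself, which is what makes this kind of signature change hard to spot without searching for every caller.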

Comment on lines 342 to 353
if input_ids is not None and inputs_embeds is not None:
    raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
    raise ValueError("You have to specify either input_ids or inputs_embeds")

# genereate a fake input_ids according to inputs_embeds.
if input_ids is None and inputs_embeds is not None:
    input_ids = self.prepare_input_ids_for_generation(1, inputs_embeds)
if inputs_embeds is not None:
    batch, seq_len, hidden_dim = inputs_embeds.shape
    inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
    model_kwargs["inputs_embeds"] = inputs_embeds
Contributor

The llama model already accepts an inputs_embeds argument, so you can simply put it into model_inputs and let it be passed through to the model.

@@ -83,6 +82,122 @@ def to_static(self, output_path: str, config: dict):
model = paddle.jit.to_static(self.generate, input_spec=input_spec)
paddle.jit.save(model, output_path)

# this function make generate_with_image_features to static inference model.
def generate_with_image_features_to_static(self, output_path: str, config: dict):
Contributor

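I still suggest using the to_static method directly to avoid duplicated work. The adjustment is fairly simple: override to_static in the llama for causallm class, rather than lifting this logic into the generic function.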
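The suggestion above can be sketched roughly as follows. This is a simplified illustration with hypothetical class names (`GenerationMixin` and `ImageTextInferenceModel` here are stand-ins, not the real PaddleNLP classes), and it only records what would be exported instead of calling `paddle.jit.to_static` and `paddle.jit.save`.

```python
class GenerationMixin:
    """Hypothetical stand-in for the shared, model-agnostic generation code."""

    def generate(self, input_ids):
        return list(input_ids)

    def to_static(self, output_path, config):
        # The real method would call paddle.jit.to_static and paddle.jit.save here;
        # this sketch only records what would be exported.
        return {
            "output_path": output_path,
            "method": self.generate.__name__,
            "input_spec": ["input_ids"],
            "dtype": config.get("dtype", "float32"),
        }


class ImageTextInferenceModel(GenerationMixin):
    """Hypothetical model class that needs a different export entry point."""

    def generate_text_with_image_features(self, image_features, first_input_ids, second_input_ids):
        return list(first_input_ids) + list(second_input_ids)

    def to_static(self, output_path, config):
        # Model-specific export logic lives in the model class, under the same
        # method name, so callers never need a second export function.
        return {
            "output_path": output_path,
            "method": self.generate_text_with_image_features.__name__,
            "input_spec": ["image_features", "first_input_ids", "second_input_ids"],
            "dtype": config.get("dtype", "float32"),
        }


generic = GenerationMixin().to_static("out/generic", {"dtype": "float16"})
custom = ImageTextInferenceModel().to_static("out/img2txt", {"dtype": "float16"})
```

With this shape, the export script can call `model.to_static(...)` for every model type, and the generic class stays free of model-specific branches.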

], # cache_kvs
]

model = paddle.jit.to_static(self.generate_with_image_features, input_spec=input_spec)
Contributor

self.generate_with_image_features could be obtained with config.get("generate_method", self.generate) instead; that way, callers can also configure it automatically from outside.
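A minimal sketch of that pattern, assuming a hypothetical `Exporter` class whose `to_static` only reports which method it would trace (the real method would pass the selected callable to `paddle.jit.to_static`):

```python
class Exporter:
    """Hypothetical class; only the method-selection logic is shown."""

    def generate(self):
        return "text-only generation"

    def generate_with_image_features(self):
        return "generation from image features"

    def to_static(self, output_path, config):
        # Fall back to self.generate unless the caller supplies another method.
        generate_method = config.get("generate_method", self.generate)
        return generate_method.__name__


model = Exporter()
default_choice = model.to_static("out/model", {})
custom_choice = model.to_static(
    "out/model", {"generate_method": model.generate_with_image_features}
)
```

The caller chooses the export target through `config`, so a single `to_static` serves both cases.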

Comment on lines 69 to 74
predictor.model.generate_with_image_features_to_static(
    get_infer_model_path(export_args.output_path, predictor_args.model_prefix), {"dtype": predictor_args.dtype}
)
predictor.model.config.save_pretrained(export_args.output_path)
predictor.tokenizer.save_pretrained(export_args.output_path)
generate_rank_mapping(os.path.join(export_args.output_path, "rank_mapping.csv"))
Contributor

This file could actually be merged into export_model.py.

@wj-Mcat
Contributor

wj-Mcat commented Sep 4, 2023

Also, please update your branch against develop; your code looks somewhat behind.

@zhoutianzi666 zhoutianzi666 force-pushed the support_kaiyuan_minigpt4 branch 2 times, most recently from da0a1de to bee52b6 Compare September 5, 2023 07:34
Contributor

@wj-Mcat wj-Mcat left a comment

The adjustments here are very good, but there are a few small issues, noted below.

Also, once #6923 is merged, I suggest using model_type instead of llm_for_img2txt.

llm/predictor.py Outdated
if predictor_args.llm_for_img2txt:
# we use llama for img2txt.
from paddlenlp.experimental.transformers import (
LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
Contributor

Suggested change
- LlamaForminiGPT4InferenceModel as LlamaInferenceModel,
+ LlamaForMiniGPT4InferenceModel as LlamaInferenceModel,

Comment on lines 240 to 251
if input_ids is not None and inputs_embeds is not None:
    raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is None and inputs_embeds is None:
    raise ValueError("You have to specify either input_ids or inputs_embeds")

# genereate a fake input_ids according to inputs_embeds.
if input_ids is None and inputs_embeds is not None:
    input_ids = self.prepare_input_ids_for_generation(self.config.bos_token_id, inputs_embeds)
if inputs_embeds is not None:
    batch, seq_len, hidden_dim = inputs_embeds.shape
    inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
    model_kwargs["inputs_embeds"] = inputs_embeds
Contributor

This logic needs to move into the model's forward rather than live in generation_utils. For reference, see: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1189

experimental/transformers/llama/modeling.py currently has no corresponding check, so I suggest moving this part of the code there. Thank you very much.
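As a rough illustration of what moving the check into forward could look like, the sketch below uses a hypothetical `TinyModel` with nested Python lists in place of Paddle tensors. It is not the PaddleNLP implementation; it only mirrors the mutual-exclusion check and the flattening from [batch, seq_len, hidden_dim] to [batch * seq_len, hidden_dim] shown in the snippet under review.

```python
class TinyModel:
    """Hypothetical model; nested lists stand in for tensors."""

    def __init__(self, vocab_size, hidden_dim):
        # Toy embedding table: token i maps to a vector filled with float(i).
        self.embed_table = [[float(i)] * hidden_dim for i in range(vocab_size)]

    def forward(self, input_ids=None, inputs_embeds=None):
        # The validation lives in the model, not in the generic generation code.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            # Look up embeddings: [batch, seq_len] -> [batch, seq_len, hidden_dim]
            inputs_embeds = [[self.embed_table[tok] for tok in seq] for seq in input_ids]

        # Flatten [batch, seq_len, hidden_dim] -> [batch * seq_len, hidden_dim]
        return [vec for seq in inputs_embeds for vec in seq]


model = TinyModel(vocab_size=4, hidden_dim=2)
from_ids = model.forward(input_ids=[[1, 2], [3, 0]])
from_embeds = model.forward(inputs_embeds=[[[0.5, 0.5]]])
```

Keeping the check in forward means every caller gets the same validation, whether or not it goes through the generation utilities.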

Contributor Author

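> This logic needs to move into the model's forward rather than live in generation_utils. For reference, see: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1189
>
> experimental/transformers/llama/modeling.py currently has no corresponding check, so I suggest moving this part of the code there. Thank you very much.

Updated. Please take another look.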

Contributor

@wj-Mcat wj-Mcat left a comment

LGTM

@wj-Mcat wj-Mcat merged commit 294df07 into PaddlePaddle:develop Sep 7, 2023
4 checks passed