Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add infer predict function for PET #2690

Merged
merged 16 commits into from Jul 1, 2022
Merged

add infer predict function for PET #2690

merged 16 commits into from Jul 1, 2022

Conversation

lastrei
Copy link
Contributor

@lastrei lastrei commented Jun 30, 2022

PR types

Others

PR changes

Models

Description

add infer predict function for PET

@wawltor wawltor self-requested a review June 30, 2022 07:56
@wawltor wawltor self-assigned this Jun 30, 2022
@lastrei
Copy link
Contributor Author

lastrei commented Jun 30, 2022

infer predict example

model_dir is the path of infer model exist
!python predict.py --model_dir './inference_model/' --task_name 'tnews'

@lastrei lastrei marked this pull request as draft June 30, 2022 12:55
@lastrei lastrei marked this pull request as ready for review July 1, 2022 00:27
from paddlenlp.data import Stack, Tuple, Pad
from paddlenlp.datasets import load_dataset
from paddlenlp.utils.log import logger
from functools import partial
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个import提前一下,这个属于PE8标准

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

tokenizer = ppnlp.transformers.ErnieTokenizer.from_pretrained('ernie-1.0')

batchify_test_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id), # src_ids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pad这块可以设置一下dtype="int64", 在windows不会出现不兼容的问题

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

parser.add_argument("--pattern_id", default=0, type=int, help="pattern id of pet")
args = parser.parse_args()


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

空行可以去掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

self.batch_size = batch_size

model_file = model_dir + "/inference.pdmodel"
params_file = model_dir + "/inference.pdiparams"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块是不是可以os.path.join这种,在windows的目录层次是 \ 符号

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

self.output_handle = self.predictor.get_output_handle(
self.predictor.get_output_names()[0])

if args.benchmark:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

benchmark这块的代码可以去掉,PET这块不需要benchmark监测

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

parser.add_argument('--enable_mkldnn', default=False, type=eval, choices=[True, False],
help='Enable to use mkldnn to speed up when using cpu.')

parser.add_argument("--benchmark", type=eval, default=False,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

benchmark的args可以去掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix

batchify_fn=batchify_test_fn,
trans_fn=trans_test_func)
p = Predictor(args.model_dir)
y = p.predict(tokenizer, test_data_loader, label_norm_dict)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

整体需要一个文档介绍一下怎么使用,例如TRT的安装,可以拷贝一下其他的文档;同时predict.py怎么使用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add README.md for how to use predict.py

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add README.md for how to use predict.py

OK,Thank for your contribution!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add README.md for how to use predict.py

OK,Thank for your contribution!

Thanks for your patience while i make these changes

@lastrei lastrei marked this pull request as draft July 1, 2022 04:38
@lastrei lastrei marked this pull request as ready for review July 1, 2022 05:18
Copy link
Collaborator

@wawltor wawltor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@wawltor wawltor merged commit b0a6061 into PaddlePaddle:develop Jul 1, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants