Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[utc] fix loading local model in taskflow #4505

Merged
merged 5 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions applications/zero_shot_text_classification/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ python run_train.py \
--disable_tqdm True \
--metric_for_best_model macro_f1 \
--load_best_model_at_end True \
--save_total_limit 1
--save_total_limit 1 \
--save_plm
```

如果在GPU环境中使用,可以指定gpus参数进行多卡训练:
Expand Down Expand Up @@ -143,7 +144,8 @@ python -u -m paddle.distributed.launch --gpus "0,1" run_train.py \
--disable_tqdm True \
--metric_for_best_model macro_f1 \
--load_best_model_at_end True \
--save_total_limit 1
--save_total_limit 1 \
--save_plm
```

该示例代码中由于设置了参数 `--do_eval`,因此在训练完会自动进行评估。
Expand Down Expand Up @@ -204,7 +206,7 @@ python run_eval.py \
>>> from pprint import pprint
>>> from paddlenlp import Taskflow
>>> schema = ["病情诊断", "治疗方案", "病因分析", "指标解读", "就医建议", "疾病表述", "后果表述", "注意事项", "功效作用", "医疗费用", "其他"]
>>> my_cls = Taskflow("zero_shot_text_classification", schema=schema, task_path='./checkpoint/model_best', precision="fp16")
>>> my_cls = Taskflow("zero_shot_text_classification", schema=schema, task_path='./checkpoint/model_best/plm', precision="fp16")
>>> pprint(my_cls("中性粒细胞比率偏低"))
```

Expand All @@ -221,7 +223,7 @@ from paddlenlp import SimpleServer, Taskflow
schema = ["病情诊断", "治疗方案", "病因分析", "指标解读", "就医建议"]
utc = Taskflow("zero_shot_text_classification",
schema=schema,
task_path="../../checkpoint/model_best/",
task_path="../../checkpoint/model_best/plm",
precision="fp32")
app = SimpleServer()
app.register_taskflow("taskflow/utc", utc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ schema = ["病情诊断", "治疗方案", "病因分析", "指标解读", "就

```python
# Default task_path
utc = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/", schema=schema)
utc = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/plm", schema=schema)
```

#### 多卡服务化预测
PaddleNLP SimpleServing 支持多卡负载均衡预测,只需在服务化注册时注册两个 Taskflow 的 task 即可,下面是示例代码:

```python
utc1 = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/", schema=schema)
utc2 = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/", schema=schema)
utc1 = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/plm", schema=schema)
utc2 = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/plm", schema=schema)
service.register_taskflow("taskflow/utc", [utc1, utc2])
```

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Change the schema to your own custom schema
schema = ["病情诊断", "治疗方案", "病因分析", "指标解读", "就医建议", "疾病表述", "后果表述", "注意事项", "功效作用", "医疗费用", "其他"]
# Change the task path to the path of your best model
utc = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/", schema=schema)
utc = Taskflow("zero_shot_text_classification", task_path="../../checkpoint/model_best/plm", schema=schema)
# Define the service for the fine-tuned UTC model
app = SimpleServer()
app.register_taskflow("taskflow/utc", utc)
30 changes: 26 additions & 4 deletions paddlenlp/taskflow/zero_shot_text_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ class ZeroShotTextClassificationTask(Task):
"special_tokens_map": "special_tokens_map.json",
"tokenizer_config": "tokenizer_config.json",
}
resource_files_urls = {
"utc-large": {
"model_state": [
"https://bj.bcebos.com/paddlenlp/taskflow/zero_shot_text_classification/utc-large/model_state.pdparams",
"71eb9a732c743a513b84ca048dc4945b",
],
"config": [
"https://bj.bcebos.com/paddlenlp/taskflow/zero_shot_text_classification/utc-large/config.json",
"9496be2cc99f7e6adf29280320274142",
],
"vocab_file": [
"https://bj.bcebos.com/paddlenlp/taskflow/zero_text_classification/utc-large/vocab.txt",
"afc01b5680a53525df5afd7518b42b48",
],
"special_tokens_map": [
"https://bj.bcebos.com/paddlenlp/taskflow/zero_text_classification/utc-large/special_tokens_map.json",
"2458e2131219fc1f84a6e4843ae07008",
],
"tokenizer_config": [
"https://bj.bcebos.com/paddlenlp/taskflow/zero_text_classification/utc-large/tokenizer_config.json",
"dcb0f3257830c0eb1f2de47f2d86f89a",
],
},
}

def __init__(self, task: str, model: str = "utc-large", schema: list = None, **kwargs):
super().__init__(task=task, model=model, **kwargs)
Expand All @@ -64,6 +88,7 @@ def __init__(self, task: str, model: str = "utc-large", schema: list = None, **k
self._pred_threshold = kwargs.get("pred_threshold", 0.5)
self._num_workers = kwargs.get("num_workers", 0)

self._check_task_files()
self._construct_tokenizer()
self._check_predictor_type()
self._get_inference_model()
Expand Down Expand Up @@ -102,10 +127,7 @@ def _construct_model(self, model):
"""
Construct the inference model for the predictor.
"""
if self.from_hf_hub:
model_instance = UTC.from_pretrained(self._task_path, from_hf_hub=self.from_hf_hub)
else:
model_instance = UTC.from_pretrained(model)
model_instance = UTC.from_pretrained(self._task_path, from_hf_hub=self.from_hf_hub)
self._model = model_instance
self._model.eval()

Expand Down