Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add hierarchical text classification #2501

Merged
merged 6 commits into from
Jun 27, 2022
Merged

add hierarchical text classification #2501

merged 6 commits into from
Jun 27, 2022

Conversation

lugimzzz
Copy link
Contributor

@lugimzzz lugimzzz commented Jun 13, 2022

PR types

Others

PR changes

Others

Description

新增层次分类算法,包含训练,预测,静态图导出,裁剪等算法,支持paddle serving, triton, onnxruntime多种部署方式。

以下是本项目主要代码结构及说明:

hierarchical_classification/
├── deploy # 部署
│   └── onnxruntime # 导出ONNX模型并基于ONNXRuntime部署
│   │   ├── infer.py # ONNXRuntime推理部署示例
│   │   ├── predictor.py
│   │   └── README.md # 使用说明
│   ├── serving # 基于Paddle Serving 部署
│   │   ├──config.yml # 层次分类任务启动服务端的配置文件
│   │   ├──rpc_client.py # 层次分类任务发送pipeline预测请求的脚本
│   │   └──service.py # 层次分类任务启动服务端的脚本
│   └── triton # 基于Triton server 部署
│       ├── README.md # 使用说明
│       ├── seqcls_grpc_client.py # 客户端测试代码
│       └── models # 部署模型
│           ├── seqcls
│           │   └── config.pbtxt
│           ├── seqcls_model
│           │   └──config.pbtxt
│           ├── seqcls_postprocess
│           │   ├── 1
│           │   │   └── model.py
│           │   └── config.pbtxt
│           └── tokenizer
│               ├── 1
│               │   └── model.py
│               └── config.pbtxt
├── train.py # 训练评估脚本
├── predict.py # 预测脚本
├── export_model.py # 动态图参数导出静态图参数脚本
├── utils.py # 工具函数脚本
├── metric.py # metric脚本
├── prune.py # 裁剪脚本
├── prune_trainer.py # 裁剪trainer脚本
├── prune_config.py # 裁剪训练参数配置
├── requirements.txt # 环境依赖
└── README.md # 使用说明

@wawltor wawltor self-requested a review June 14, 2022 06:38
@wawltor wawltor self-assigned this Jun 14, 2022
@wawltor wawltor added the enhancement New feature or request label Jun 14, 2022
@@ -0,0 +1,251 @@
# 多标签层次分类任务
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. pre-commit 没有安装,code style check没有过

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已安装

## 层次分类任务介绍

多标签层次分类任务指自然语言处理任务中,每个样本具有多个标签标记,并且标签集合中存在预定义的树状结构或有向无环图结构,多标签层次分类需要充分考虑标签集之间的层次结构关系来预测层次化预测结果。在现实场景中,大量的数据如新闻分类、专利分类、学术论文分类等标签集合存在预定义的层次化结构,需要利用算法为文本自动标注更细粒度和更准确的标签。本项目中采用通用多标签层次分类算法将每个结点的标签路径视为一个多类标签,使用单个分类器进行决策。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块能不能用一些图片或者例子来示意层次分类了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

新增图片表示:
如下图所示(R代表根节点),层次分类任务中标签层次结构分为两类,一类为树状结构,另一类为有向无环图(DAG)结构。有向无环图结构与树状结构区别在于,有向无环图中的节点可能存在不止一个父节点。层次分类问题可以被视为一个多标签问题,以左图树状结构为例,如果一个样本属于类别1.2.1,样本也天然地同时属于类别1和类别1.2两个样本标签。本项目采用通用多标签层次分类算法,将每个结点的标签路径视为一个多分类标签,使用单个多标签分类器进行决策。以上面的例子为例,该样本包含三个标签:1、1->1.2、1->1.2->1.2.1。
截屏2022-06-15 下午2 59 24


## 层次分类任务介绍

多标签层次分类任务指自然语言处理任务中,每个样本具有多个标签标记,并且标签集合中存在预定义的树状结构或有向无环图结构,多标签层次分类需要充分考虑标签集之间的层次结构关系来预测层次化预测结果。在现实场景中,大量的数据如新闻分类、专利分类、学术论文分类等标签集合存在预定义的层次化结构,需要利用算法为文本自动标注更细粒度和更准确的标签。本项目中采用通用多标签层次分类算法将每个结点的标签路径视为一个多类标签,使用单个分类器进行决策。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

多类标签->多分类标签

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

## 模型微调

我们以层次分类公开数据集WOS(Web of Science)为示例,在训练集上进行模型训练,并在开发集上验证,开发集中选出的最优的模型在测试集上进行评估。WOS数据集是一个层次文本分类数据集,包含7个父类和134子类,每个样本对应一个父类标签和子类标签,父类标签和子类标签间具有层次结构关系,WOS数据集已内置到PaddleNLP中。

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

,WOS数据集已内置到PaddleNLP中。 这句话可以去掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已去掉

单卡训练
```shell
$ unset CUDA_VISIBLE_DEVICES
$ python train.py --early_stop
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是单卡训练不应该 unset CUDA_VISIBLE_DEVICES , 多卡的时候需要unset一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已去掉

指定GPU卡号/多卡训练
```shell
$ unset CUDA_VISIBLE_DEVICES
$ python -m paddle.distributed.launch --gpus "0,1" train.py --early_stop
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里建议 --gpus "0,1" -> --gpus "0", 因为大多数用户是没有多卡,可以备注一下 如果用多卡,可以指定0,1,这样的数字

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改:

指定GPU卡号/多卡训练

unset CUDA_VISIBLE_DEVICES
python -m paddle.distributed.launch --gpus "0" train.py --early_stop

使用多卡训练可以指定多个GPU卡号,例如 --gpus "0,1"

**NOTE:**
* 如需恢复模型训练,则可以设置 `init_from_ckpt` , 如 `init_from_ckpt=checkpoints/macro/model_state.pdparams` 。
* 如需训练中文层次分类任务,只需更换预训练模型参数 `model_name` 。中文训练任务推荐使用"ernie-3.0-base-zh",更多可选模型可参考[Transformer预训练模型](https://paddlenlp.readthedocs.io/zh/latest/model_zoo/index.html#transformer)。
* 如需使用ernie-tiny模型,则需要提前先安装sentencepiece依赖,如 `pip install sentencepiece`。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块话去掉把,后续我们建议用户使用的ERNIE-3.0-tiny,使用的tokenizer是不需要sentencepiece

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


程序运行时将会自动进行训练,评估,测试。同时训练过程中会自动保存开发集上最佳 Macro F1 值和最佳 Micro F1 值的模型在指定的 `save_dir` 中,保存模型文件结构如下所示:

```text
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里有个疑问,一般情况下如果选择macro和micro两个其中的模型了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,默认保存最佳macro f1模型参数


* `params_path`:动态图训练保存的参数路径;默认为"./checkpoint/macro/model_state.pdparams"。
* `output_path`:静态图图保存的参数路径;默认为"./export"。
* `dataset`:训练数据集;默认为wos数据集。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果是导出模型的话应该是不需要dataset

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改,使用num_classes来确定AutoModelForSequenceClassification中类别数:

  • num_classes:任务标签类别数;默认为wos数据集类别数141。


```shell
python deploy/paddle2onnx/infer.py --model_path_prefix ./export/wos/1/float32
```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个脚本的模型输入会有点怪,因为这个时候还没有进行模型的裁减的工作

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

示例export_model.py导出的静态图模型可以用于onnxruntime推理


启动裁剪:
```shell
$ python prune.py --output_dir ./export
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

裁剪这块的输入模型是什么了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

启动裁剪:

python prune.py --output_dir ./prune --params_dir ./checkpoint/model_state.pdparams

* `params_dir`:待预测模型参数文件;默认为"./checkpoint/macro/model_state.pdparams"。
* `model_name_or_path`:选择预训练模型;默认为"bert-base-uncased"。

以上参数都可通过 `python xxx.py --dataset xx --params_dir xx` 的方式传入)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xxx.py -> prune.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改:
以上参数都可通过 python prune.py --dataset xx --params_dir xx 的方式传入)


1. 数据集:WOS(英文层次分类数据集)

2. 计算卡:V100、CUDA11.2
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1.物理机环境
系统: CentOS Linux release 7.5.1804
GPU: Tesla V100-SXM2-32GB * 8
CPU: Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz * 40
CUDA: 11
cuDNN: 8.0.4
Driver Version: 450.80.02
内存: 502 GB

GPU相关信息再细致一点,参考上面

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改:

  1. 物理机环境

    系统: CentOS Linux release 7.7.1908 (Core)

    GPU: Tesla V100-SXM2-32GB * 8

    CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz

    CUDA: 11.2

    cuDNN: 8.1.0

    Driver Version: 460.27.04

    内存: 630 GB

| BERT base | 86.06 | 81.29 | 8.80 |
| BERT base+裁剪(3/4) | 86.83(+0.77) | 81.08(-0.21) | 6.85 |
| BERT base+裁剪(2/3) | 86.77(+0.71) | 80.48(-0.81) | 5.98 |
| BERT base+裁剪(1/4) | 86.40(+0.34) | 80.79(-0.5) | 2.51 |
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块可能最好是基于ERNIE-2.0-en来做实验,这里的实验细节没有看明白, BERT base+裁剪(3/4) 裁剪越多,latency最高?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

裁剪后面数值为保留比例,已更换ernie-2.0-base-en进行实验:

Micro F1 Macro F1 latency(ms)
ERNIE 2.0 85.71 80.82 8.80
ERNIE 2.0+裁剪(保留比例3/4) 86.83(+1.12) 81.78(+0.96) 6.85
ERNIE 2.0+裁剪(保留比例2/3) 86.74(+1.03) 81.64(+0.82) 5.98
ERNIE 2.0+裁剪(保留比例1/4) 85.79(+0.08) 79.53(-1.29) 2.51

# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的import要符合PE8, 这块可以看看
一般按照这个规则来,标准库模块,第三方模块,自用模块
paddle这种一般都是向后排

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

)
parser.add_argument(
'--model_name',
default="bert-base-uncased",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块换成ERNIE EN系列的模型

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

if paddle.distributed.get_world_size() > 1:
paddle.distributed.init_parallel_env()

# load and preprocess dataset
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释手字母大写,整体都改一下了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改


# batchify dataset
collate_fn = DataCollatorWithPadding(tokenizer)
train_batch_sampler = BatchSampler(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块有点问题,如果是多卡时,应该使用的DistributedBatchSampler

Copy link
Contributor Author

@lugimzzz lugimzzz Jun 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改:

if paddle.distributed.get_world_size() > 1:
    train_batch_sampler = DistributedBatchSampler(train_ds, batch_size=args.batch_size, shuffle=True)
else:
    train_batch_sampler = BatchSampler(train_ds, batch_size=args.batch_size, shuffle=True)

lr_scheduler.step()
optimizer.clear_grad()

global_step += 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是因为step是从1开始,这块的global_step这块的逻辑是不是有点不太对

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

global_step是从0开始

early_stop_count = 0
best_micro_f1_score = micro_f1_score
model._layers.save_pretrained(save_best_micro_path)
tokenizer.save_pretrained(save_best_micro_path)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Save的逻辑可能再讨论一下,是不是mean值会更好,放出太多的模型会让用户比较迷

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

选择保留最佳macro f1模型

tokenizer.save_pretrained(save_best_micro_path)


def test():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里比较有疑问的是,和evaluate的区别是什么,看这里加载的模型是marco f1值最好的模型,这里的考虑是什么了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

test()是打算评测测试集表现,已去掉test函数

if step % 100 == 0:
logger.info("step %d, %d samples processed" %
(step, step * args.batch_size))
metric.report()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的report里面看起来是使用了sklearn相关的函数,在requirement没有体现出来

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已新加requirements.txt

from paddlenlp.utils.log import logger


# 构建验证集evaluate函数
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

注释改成英文注释,统一一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

import关系要符合一下PE8标准

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@paddle.no_grad()
def evaluate(model, criterion, metric, data_loader):
"""
Given a dataset, it evals model and computes the metric.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

evals看起来没有这个单测,eval->evaluate

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改为evaluates

import yaml
import functools
from typing import Optional
import paddle
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块的import关系再check一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

from metric import MetricReport

nn.MultiHeadAttention._ori_forward = paddle.nn.MultiHeadAttention.forward
nn.MultiHeadAttention._ori_prepare_qkv = nn.MultiHeadAttention._prepare_qkv
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块的函数变换,可以问一下佳琪的原因,然后注释一下

Copy link
Contributor Author

@lugimzzz lugimzzz Jun 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加注释:

# Paddleslim will modify MultiHeadAttention.forward and MultiHeadAttention._prepare_qkv
# Original forward and _prepare_qkv should be saved before import paddleslim
nn.MultiHeadAttention._ori_forward = paddle.nn.MultiHeadAttention.forward
nn.MultiHeadAttention._ori_prepare_qkv = nn.MultiHeadAttention._prepare_qkv

@@ -0,0 +1,58 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块的文件改一下吧,和之前的一些保持一致,export.py->export_model.py

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改脚本名

help="The path to model parameters to be loaded.")
parser.add_argument("--output_path", type=str, default='./export',
help="The path of model parameter in static graph to be saved.")
parser.add_argument("--dataset", default="wos", type=str,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dataset没有使用,可以去掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

args = parser.parse_args()


def predict(data, label_list):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一个@paddle.no_grad,节省显存占用

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已加上

results.append(labels)

for idx, text in enumerate(data):
label_name = [label_list[r] for r in results[idx]]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里对结果进行解析的过程中,回到之前的问题,因为这次是一个层次分类的任务,最终输出看起来是多标签分类的结果

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

for r in results[idx]:
    if r < 7:
        level1.append(label_list[r])
    else:
        level2.append(label_list[r])
        print('predicted result:')
        print('level 1 : {} level 2 : {}'.format(', '.join(level1), ', '.join(
            level2)))

## 环境准备

模型转换与ONNXRuntime预测部署依赖Paddle2ONNX和ONNXRuntime,Paddle2ONNX支持将Paddle模型转化为ONNX模型格式,算子目前稳定支持导出ONNX Opset 7~15,更多细节可参考:[Paddle2ONNX](https://github.com/PaddlePaddle/Paddle2ONNX)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的依赖安装是不是也要增加一下paddle2onnx

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddlenlp安装的时候已经安装paddle2onnx了


2. 计算卡:V100、CUDA11.2

3. CPU 信息:Intel(R) Xeon(R) Gold 6271C CPU
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的硬件设备信息,根据trianer再正式一点

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改:
2. 物理机环境

系统: CentOS Linux release 7.7.1908 (Core)

GPU: Tesla V100-SXM2-32GB * 8

CPU: Intel(R) Xeon(R) Gold 6271C CPU @ 2.60GHz

CUDA: 11.2

cuDNN: 8.1.0

Driver Version: 460.27.04

内存: 630 GB


import paddle
import argparse
from predictor import Predictor
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这块的import往下摞

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

import onnxruntime as ort
from paddlenlp.transformers import AutoTokenizer
import paddle.nn.functional as F
from sklearn.metrics import f1_score
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

咱们的框架是有F1只算的 AccuracyAndF1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sklearn可以计算macro和micro f1

def printer(self, infer_result, input_datas):
label = infer_result["label"]
confidence = infer_result["confidence"]
for i, input_data in enumerate(input_datas):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里看看,能不能改成logger

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

@lugimzzz
Copy link
Contributor Author

按照comment修改了代码和readme,基于ernie-2.0-en的性能和精度评测结果后续补上,serving部署部分明后两天补上。
裁剪部分的参数width_mul还是按照保留比例设置,因为涉及改的代码比较多,在readme上性能和精度评测结果栏里注明了是裁剪保留比例。

@lugimzzz
Copy link
Contributor Author

已增加triton serving部署部分代码和文档说明

@lugimzzz
Copy link
Contributor Author

完善triton serving部分,并补充ERNIE 2.0在wos数据集上表现和性能(裁剪,fp16,int8)

@lugimzzz
Copy link
Contributor Author

由于pre-commit-config.yaml变化,更新了yapf版本,导致代码codestyle不通过。通过merge develop分支最新代码,更新pre-commit文件

@lugimzzz
Copy link
Contributor Author

lugimzzz commented Jun 23, 2022

新增层次分类读取本地数据方式,修改paddlenlp/dataset/wos.py以支持以内置数据集格式读取本地数据集。目前本地数据集读取支持树状图和有向无环图标签层次结构,支持数据集层次标签不同深度,用训练、预测、部署仅需提供设定格式本地数据集路径和参数配置。

@ArtificialZeng
Copy link
Contributor

新增层次分类读取本地数据方式,修改paddlenlp/dataset/wos.py以支持以内置数据集格式读取本地数据集。目前本地数据集读取支持树状图和有向无环图标签层次结构,支持数据集层次标签不同深度,用训练、预测、部署仅需提供设定格式本地数据集路径和参数配置。

Traceback (most recent call last):
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\threading.py", line 926, in _bootstrap_inner
self.run()
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\threading.py", line 870, in run
self._target(*self._args, **self._kwargs)
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddle\fluid\dataloader\dataloader_iter.py", line 216, in _thread_loop
self._thread_done_event)
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddle\fluid\dataloader\fetcher.py", line 121, in fetch
data.append(self.dataset[idx])
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddlenlp\datasets\dataset.py", line 272, in getitem
idx]) if self._transform_pipline else self.new_data[idx]
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddlenlp\datasets\dataset.py", line 263, in _transform
data = fn(data)
File "G:\googleDownload\PaddleNLP-hierarchical\PaddleNLP-hierarchical\applications\text_classification\hierarchical_classification\utils.py", line 73, in preprocess_function
for i in range(label_nums)
File "G:\googleDownload\PaddleNLP-hierarchical\PaddleNLP-hierarchical\applications\text_classification\hierarchical_classification\utils.py", line 73, in
for i in range(label_nums)
KeyError: 'label'

虽然数据可以读取本地了, 但还是报错

@lugimzzz
Copy link
Contributor Author

hierarchical\applications\text_classification\hierarchical_classification\utils.py", line 73, in preprocess_function

utils中preprocess_function.py已经发生变动,如需使用请使用最新代码(其他多个文件也在输入输出部分出现变化,请一并拉取)

@ArtificialZeng
Copy link
Contributor

新增层次分类读取本地数据方式,修改paddlenlp/dataset/wos.py以支持以内置数据集格式读取本地数据集。目前本地数据集读取支持树状图和有向无环图标签层次结构,支持数据集层次标签不同深度,用训练、预测、部署仅需提供设定格式本地数据集路径和参数配置。

W0624 11:18:51.985149 25272 gpu_context.cc:306] device: 0, cuDNN Version: 8.0.
Exception in thread Thread-4:
Traceback (most recent call last):
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\threading.py", line 926, in _bootstrap_inner
self.run()
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\threading.py", line 870, in run
self._target(*self._args, **self._kwargs)
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddle\fluid\dataloader\dataloader_iter.py", line 216, in _thread_loop
self._thread_done_event)
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddle\fluid\dataloader\fetcher.py", line 121, in fetch
data.append(self.dataset[idx])
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddlenlp\datasets\dataset.py", line 268, in getitem
) if self._transform_pipline else self.new_data[idx]
File "E:\A_IDE\Anaconda3\envs\pp212_gpu\lib\site-packages\paddlenlp\datasets\dataset.py", line 258, in _transform
data = fn(data)
TypeError: preprocess_function() got an unexpected keyword argument 'label_nums'

bug是不是有点太多了

@ArtificialZeng
Copy link
Contributor

hierarchical\applications\text_classification\hierarchical_classification\utils.py", line 73, in preprocess_function

utils中preprocess_function.py已经发生变动,如需使用请使用最新代码(其他多个文件也在输入输出部分出现变化,请一并拉取)

谢谢,大佬, 拉取最新的后 跑通了。

请问可以应用于商品同款匹配嘛

@lugimzzz
Copy link
Contributor Author

本层次文本分类任务目前主要用于对文本片段进行分类,确定文本在不同层次标签中所属的类别。同款商品匹配不建议使用文本分类,中文商品匹配可以参考application中语义索引任务进行搭建

@ArtificialZeng
Copy link
Contributor

本层次文本分类任务目前主要用于对文本片段进行分类,确定文本在不同层次标签中所属的类别。同款商品匹配不建议使用文本分类,中文商品匹配可以参考application中语义索引任务进行搭建

什么时候有中文的 文本分类, 怎么示例用英文的呢?
以下网站有一个中文专利的: 要是下载不了可以留邮箱 我发您
https://opendata.pku.edu.cn/dataset.xhtml?persistentId=doi:10.18170/DVN/ASRTHL&version=2.0

@lugimzzz
Copy link
Contributor Author

什么时候有中文的 文本分类, 怎么示例用英文的呢?
考虑到目前层次分类相关中文数据集非正式数据集,所以层次分类使用英文数据集进行示例,之后层次分类中文数据集会以notebook的形式呈现,放入到文档中。之后多标签,多分类的文本分类都采用中文数据集示例。

@lugimzzz
Copy link
Contributor Author

加入Paddle Serving部署

Copy link
Collaborator

@wawltor wawltor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kevindragon
Copy link

运行的时候报错:
image

python -m paddle.distributed.launch --gpus 0 applications/text_classification/hierarchical_classification/train.py --early_stop

运行目录为根目录,版本信息如下:

paddlepaddle 2.3.0
paddlenlp 2.3.0.dev

@lugimzzz
Copy link
Contributor Author

运行的时候报错: image

python -m paddle.distributed.launch --gpus 0 applications/text_classification/hierarchical_classification/train.py --early_stop

因为paddlenlp/dataset/wos.py有变化,目前代码还未merge到develop分支中,如果要加载wos数据集,建议替换paddlenlp安装目录下paddlenlp/dataset/wos.py
note:可以使用pip show paddlenlp确定安装location.

@ArtificialZeng
Copy link
Contributor

LGTM

什么时候合并到主分支 ,刚在本地运行成功了

@PaddlePaddle PaddlePaddle locked and limited conversation to collaborators Jun 27, 2022
@PaddlePaddle PaddlePaddle unlocked this conversation Jun 27, 2022
@wawltor wawltor closed this Jun 27, 2022
@wawltor wawltor reopened this Jun 27, 2022
@wawltor wawltor merged commit 08f0633 into PaddlePaddle:develop Jun 27, 2022
@lugimzzz lugimzzz deleted the hierarchical branch July 4, 2022 09:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants