
Sentiment Analysis Application #1505

Merged 37 commits (Dec 27, 2021)

Commits:
- `38786a7` sentiment analysis initializing (1649759610, Dec 13, 2021)
- `d458c55` modify README.md (1649759610, Dec 13, 2021)
- `2785b2b` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `227cb1e` modify data_ext and data_cls link in README.md (1649759610, Dec 23, 2021)
- `d82cb13` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `6e5d9b2` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `4d12fd7` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `df7e55f` delete unuseful info. (1649759610, Dec 23, 2021)
- `a45c887` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `79ac7c5` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `c40edd0` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `dce2c24` sentiment analysis initializing (1649759610, Dec 23, 2021)
- `8b18b36` delete sentiment_system.png (1649759610, Dec 24, 2021)
- `3d9429a` add sentiment_system.png (1649759610, Dec 24, 2021)
- `dce8380` refine readme.md (1649759610, Dec 24, 2021)
- `0fbea92` Merge branch 'develop' into mynlp (ZeyuChen, Dec 24, 2021)
- `fe64135` sentiment analysis intializing (1649759610, Dec 24, 2021)
- `a9befda` Merge branch 'mynlp' of https://github.com/1649759610/PaddleNLP into … (1649759610, Dec 24, 2021)
- `07fa679` mv data and checkpoints in sub_dir to parent_dir (1649759610, Dec 24, 2021)
- `0dda6a0` refine readme.md (1649759610, Dec 24, 2021)
- `1c30c3e` refine readme (1649759610, Dec 24, 2021)
- `50aa913` refine readme.md, modify requirements (1649759610, Dec 24, 2021)
- `1b34ee6` refine readme.md (1649759610, Dec 24, 2021)
- `d518b37` refine readme.md (1649759610, Dec 24, 2021)
- `e15cd91` refine readme.md (1649759610, Dec 24, 2021)
- `e21659e` mv run_export.sh to run_export_model.sh (1649759610, Dec 24, 2021)
- `c06169e` refine readme.md (1649759610, Dec 26, 2021)
- `240af3f` remove some (1649759610, Dec 26, 2021)
- `0bca33e` remove some unnecessary packages (1649759610, Dec 26, 2021)
- `56af6c2` delete unuseful compute_md5 method (1649759610, Dec 26, 2021)
- `95982d6` sentiment analysis initializing (1649759610, Dec 26, 2021)
- `f9f4b51` sentiment analysis initializing (1649759610, Dec 26, 2021)
- `0eddb94` refine readme.md (1649759610, Dec 26, 2021)
- `f7f052b` set CUDA_VISIBLE_DEVICES=0 (1649759610, Dec 26, 2021)
- `13e54c7` refine style with pre-commit (1649759610, Dec 26, 2021)
- `584934a` use ppminilm in ernie instead of offline ppminilm (1649759610, Dec 26, 2021)
- `9275b4d` Merge branch 'develop' into mynlp (chenxiaozeng, Dec 27, 2021)
142 changes: 142 additions & 0 deletions applications/sentiment_analysis/README.md
@@ -0,0 +1,142 @@
# Sentiment Analysis
> **Review comment (Contributor):** Add a table of contents at the beginning.

## 1. Scenario Overview

Text sentiment analysis, also known as opinion mining or tendency analysis, is, briefly, the process of analyzing, processing, summarizing, and reasoning over subjective text that carries emotional coloring. The internet (blogs, forums, and social review services such as Dianping) produces a large volume of user-contributed comments on people, events, products, and other subjects of interest. These comments express people's various emotions and sentiment tendencies, such as joy, anger, sorrow, and happiness, or criticism and praise. Prospective users can therefore browse these subjective comments to learn what public opinion thinks of a given event or product. [1]

Depending on the granularity of analysis, sentiment analysis is usually divided into document-level, sentence-level, and word-level tasks, which analyze the sentiment carried by an article, a sentence, or a word, respectively. The most widely known sentiment analysis task is sentence-level, as in the following sentence.

> 15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错。 (The keyboard of this 15.4-inch laptop is really great, basically on par with a desktop. I quite like the numeric keypad; entering numbers is very convenient. It also looks nice and is quite well made.)

Clearly, this review of the laptop is positive: the customer is satisfied with its design, and with the keyboard in particular. Sentence-level sentiment analysis determines whether the sentiment expressed in a sentence is positive or negative.

When people write reviews, however, they often comment on several aspects of a single target. Sentence-level analysis can only capture the overall sentiment of the whole sentence; it cannot analyze the sentiment toward each individual aspect at a finer granularity. The latter is more useful in real-world applications and gives enterprise users and merchants more concrete, actionable feedback. Consider this review of potato chips.

> 这个薯片味道真的太好了,口感很脆,只是包装很一般。 (These chips taste really great and are very crispy; it's just that the packaging is mediocre.)

The customer evaluates the chips along three aspects: texture, packaging, and taste. Taste and texture receive positive reviews, while packaging receives a negative one. Only through this kind of fine-grained analysis can merchants identify problems precisely and improve their products accordingly.

With these considerations in mind, this project provides a complete fine-grained sentiment analysis solution that performs sentiment analysis at the granularity of the aspects mentioned in a review sentence.

## 2. Product Features

### 2.1 System Highlights

- Low barrier to entry
  - A hands-on walkthrough for building a fine-grained sentiment analysis system: extract the aspects and corresponding opinions from a sentence, then run sentiment analysis at the granularity of each aspect.
  - One-stop support for training, prediction, and deployment.
- Strong results
  - Built on SKEP, a pretrained model designed specifically for sentiment analysis, addressing the insensitivity of general-purpose models to sentiment information.

> **Review comment:** Put one space between Chinese and English characters.

- High performance
  - An open-source small model and an accompanying optimization strategy to address the inference-efficiency problem of large models.

> **Review comment:** Suggested wording: "Provides an open-source small model and a quantization-based acceleration scheme, substantially improving prediction performance."


### 2.2 Architecture & Features

Our solution to the fine-grained sentiment analysis problem described above is shown in the figure below. The full sentiment analysis process consists of two stages: an aspect and opinion extraction model, followed by a fine-grained sentiment classification model. For a given piece of text, the extraction model first extracts the potential aspects in the sentence along with the opinions expressed about each of them; the aspect, the opinion, and the original text are then concatenated and passed to the fine-grained classification model, which identifies the sentiment toward that aspect.
> **Review comment (Contributor):** Unify the terminology: use 评价维度 (evaluation aspect) consistently instead of 评论维度 (comment aspect).

It should be noted that most models on the market today are trained on general-purpose corpora and may not be particularly sensitive to sentiment information. For this reason, this project uses SKEP, Baidu's self-developed pretrained model, whose pretraining stage was designed around several sentiment-related pretraining objectives. As a sentiment-specific model, SKEP is better suited to the aspect and opinion extraction task and the fine-grained sentiment classification task described above.

> **Review comment:** Put one space between Chinese and English characters.


In addition, this project uses the Large version of SKEP. Since enterprise users care about inference efficiency when deploying online, the project also provides a general-purpose small model, [PP-MiniLM](https://github.com/LiuChiachi/PaddleNLP/tree/add-ppminilm/examples/model_compression/PP-MiniLM), together with a quantization strategy: users can fine-tune PP-MiniLM on the corresponding sentiment dataset and then quantize it for faster inference.

> **Review comment:** Put one space between Chinese and English characters.

> **Review comment:** PP-MiniLM should link to the official PaddleNLP address.


<center> <img src="./imgs/sentiment_system.png" /></center>
<center>Figure 1: Sentiment analysis system</center>
> **Review comment (Contributor):** The figure caption needs to be centered.


## 3. Fine-Grained Sentiment Analysis in Practice

The complete directory structure of this project, with descriptions:

```shell
.
├── extraction                      # aspect and opinion extraction model package
├── classification                  # fine-grained sentiment classification model package
├── ppminilm                        # PP-MiniLM small-model package
├── data                            # test set for full-pipeline batch prediction and its prediction results
├── imgs                            # images
├── dynamic_predict.py              # full-pipeline dynamic-graph single-instance prediction script
├── dynamic_predict_by_batch.py     # full-pipeline dynamic-graph batch prediction script
├── export_model.py                 # dynamic-to-static model export script
├── static_predict.py               # full-pipeline static-graph single-instance prediction script
├── run_dynamic_predict.sh          # full-pipeline dynamic-graph single-instance prediction command
├── run_dynamic_predict_by_batch.sh # full-pipeline dynamic-graph batch prediction command
├── run_export.sh                   # dynamic-to-static model export command
├── run_static_predict.sh           # full-pipeline static-graph single-instance prediction command
├── requirements.txt                # environment dependencies
└── README.md
```

> **Review comment (Contributor):** Suggest unifying the naming of the shell scripts and Python scripts; also, should a separate /predict directory be created?

### 3.1 Environment and Dependency Installation

(1) Runtime environment

Unless otherwise noted, this project was developed and run in the following environment:
> **Review comment (Contributor):** In that case, add an environment-dependency list covering the versions of python, paddlepaddle, paddlenlp, and every other library. Version requirements are usually lower bounds, in a format such as:
> python >= 3.6
> paddlepaddle >= 2.2

```shell
python == 3.8
CUDA Version: 10.2
NVIDIA Driver Version: 440.64.00
GPU: Tesla V100
linux: CentOS Linux release 7.9.2009 (Core)
```

> **Review comment:** Only one space is needed here, right?
(2) Installing dependencies

Install the software environment with the following command:
```shell
pip install -r requirements.txt
```

### 3.2 Data Description

This project trains models for two stages: the aspect and opinion extraction model and the fine-grained sentiment classification model. For the extraction and classification models, we have open-sourced demo datasets: [data_ext](https://bj.bcebos.com/v1/paddlenlp/data/data_ext.tar.gz) and [data_cls](https://bj.bcebos.com/v1/paddlenlp/data/data_cls.tar.gz).

> **Review comment:** Deomo -> Demo


Download each dataset, extract it, and place the data files in the `./extraction/data` and `./classification/data` directories, respectively.

### 3.3 Aspect and Opinion Extraction Model
> **Review comment (Contributor):** Open-sourcing a model with such strong results is a major highlight; suggest moving the model up to the parent directory.

For the design and usage of the aspect and opinion extraction model, see [here](extraction/README.md).

### 3.4 Fine-Grained Sentiment Classification Model

For the design and usage of the fine-grained sentiment classification model, see [here](classification/README.md).


### 3.5 Full-Pipeline Fine-Grained Sentiment Inference
> **Review comment (Contributor):** It would be best to give input and output examples.

After the aspect and opinion extraction model and the fine-grained sentiment classification model have been trained, the best models from training are saved by default under `./extraction/checkpoints` and `./classification/checkpoints`. You can then run full-pipeline inference with the saved models: given a review sentence, the extraction model first extracts the aspects and opinions, and the fine-grained classification model then classifies sentiment polarity at the aspect level.

This project provides two full-pipeline prediction options: dynamic-graph prediction and high-performance static-graph prediction. Dynamic-graph prediction supports both single-instance and batch modes.
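To illustrate the input and output of full-pipeline inference, the sketch below walks one review through the two stages. `extract_aspects` and `classify_sentiment` are hypothetical stubs with hard-coded behavior standing in for the trained extraction and classification models; only the data flow matches the project's design.

```python
# Hypothetical stubs: the real pipeline loads the trained SKEP models from
# ./extraction/checkpoints and ./classification/checkpoints.

def extract_aspects(text):
    # Stage 1 (stub): return (aspect, opinion) pairs found in the sentence.
    return [("味道", "好"), ("包装", "一般")]

def classify_sentiment(aspect, opinion, text):
    # Stage 2 (stub): the real classifier takes the concatenated
    # aspect + opinion as the first segment and the original sentence as the
    # second segment, then predicts a polarity label.
    target_text = aspect + opinion  # e.g. "味道好"
    return "positive" if opinion == "好" else "negative"

def predict(text):
    # Chain the two stages: extract, then classify per aspect.
    results = []
    for aspect, opinion in extract_aspects(text):
        results.append({
            "aspect": aspect,
            "opinion": opinion,
            "sentiment": classify_sentiment(aspect, opinion, text),
        })
    return results

print(predict("这个薯片味道真的太好了,只是包装很一般。"))
```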

#### 3.5.1 Full-Pipeline Dynamic-Graph Prediction

Run full-pipeline dynamic-graph single-instance prediction with:
```shell
sh run_dynamic_predict.sh
```

Run dynamic-graph batch prediction with:
```shell
sh run_dynamic_predict_by_batch.sh
```
**Note**: Dynamic-graph batch prediction requires the path of a test-set file. Place the test-set file in the `data` folder of this directory; after prediction, the model writes the results to a file in the same directory as the test set. Note that each line of the test-set file must be a single sentence to be predicted, as shown below.
```
蛋糕味道不错,很好吃,店家很耐心,服务也很好,很棒
酒店干净整洁,性价比很高
酒店环境不错,非常安静,性价比还可以
房间很大,环境不错
```

#### 3.5.2 High-Performance Static-Graph Prediction

For high-performance static-graph prediction, first convert the dynamic-graph model into a static-graph model, then run inference with the high-performance Paddle Inference engine.

Convert the dynamic-graph model to a static-graph model with:
```shell
sh run_export.sh
```

Run high-performance static-graph prediction with Paddle Inference:
```shell
sh run_static_predict.sh
```

### 3.6 Small-Model Optimization Strategy

This project provides a fine-grained sentiment classification solution based on [PP-MiniLM](https://github.com/LiuChiachi/PaddleNLP/tree/add-ppminilm/examples/model_compression/PP-MiniLM), a small model tailored for Chinese. PP-MiniLM offers a complete small-model optimization pipeline: task-agnostic model distillation first, followed by model-compression techniques such as pruning and quantization built on [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim), which effectively shrink the model and speed up inference.

> **Review comment:** PP-MiniLM should link to the official address.


This project fine-tunes a fine-grained sentiment classification model from the PP-MiniLM Chinese small model and then quantizes the trained model with PaddleSlim. See [here](./ppminilm/README.md) for details.


## 4. References

[1] [Text sentiment analysis (Baidu Baike)](https://baike.baidu.com/item/%E6%96%87%E6%9C%AC%E6%83%85%E6%84%9F%E5%88%86%E6%9E%90/19431243?fr=aladdin)

[2] [SKEP paper](https://aclanthology.org/2020.acl-main.374.pdf)

> **Review comment:** Could the paper citations be put in a standard citation format?

67 changes: 67 additions & 0 deletions applications/sentiment_analysis/classification/README.md
@@ -0,0 +1,67 @@
# Fine-Grained Sentiment Classification Model



## 1. Solution Design

This project performs sentiment analysis at the aspect level within a sentence. For a given piece of text, once the aspect and opinion extraction model has extracted the aspects and opinions, each aspect can be evaluated individually. Concretely, this implementation concatenates an extracted aspect with its opinion, and then concatenates the result with the original sentence to form one independent training instance.

> **Review comment:** Typo: 然后原始语句 -> 然后和原始语句


As shown in Figure 1, the aspect and the opinion word are first concatenated into "味道好" ("tastes good"); this is then concatenated with the original text and fed into the SKEP model, and the vector at the "CLS" position is used to predict the fine-grained sentiment polarity.

<center><img src="../imgs/design_cls_model.png" /></center>

<br><center>Figure 1: Fine-grained sentiment classification model</center><br/>
> **Review comment (Contributor):** Center this; likewise elsewhere.
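The concatenation described above can be sketched as follows. This is a minimal illustration assuming the standard BERT-style sentence-pair layout ([CLS] segment A [SEP] segment B [SEP]); in the actual project this pairing is delegated to `SkepTokenizer` via its `text_pair` argument, so the string below is only a conceptual rendering, not what the tokenizer emits.

```python
# Build the classifier input for one extracted (aspect, opinion) pair.
aspect, opinion = "味道", "好"
original = "这个薯片味道真的太好了,口感很脆"

# First segment: aspect and opinion concatenated, e.g. "味道好".
target_text = aspect + opinion

# Conceptual sentence-pair layout fed to the SKEP encoder; the real
# tokenizer produces token ids rather than this literal string.
pair = "[CLS] " + target_text + " [SEP] " + original + " [SEP]"
print(pair)
```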


## 2. Project Structure

The complete directory structure of this project, with descriptions:

```shell
.
├── data            # data directory
├── checkpoints     # model checkpoint directory
│   └── static      # static-graph model directory
├── data.py         # data processing script
├── model.py        # model definition script
├── train.py        # model training script
├── evaluate.py     # model evaluation script
├── utils.py        # utility functions
├── run_train.sh    # model training command
├── run_evaluate.sh # model evaluation command
└── README.md
```

> **Review comment (Contributor, on `checkpoints`):** Is checkpoints needed?

## 3. Data Description

This model performs fine-grained sentiment analysis based on aspects and opinions, so the dataset must contain three columns. A sample row is shown below: column 1 is the sentiment label, column 2 the aspect and opinion, and column 3 the original text.

> 1 口味清淡 口味很清淡,价格也比较公道

Click [data_cls](https://bj.bcebos.com/v1/paddlenlp/data/data_cls.tar.gz) to download the demo data, extract it, and place it in the `data` folder of this directory.
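Each line of the dataset can thus be parsed into the three fields the model consumes. A minimal sketch of that parsing, following the tab-separated format described above:

```python
# Parse one tab-separated training sample into (label, target_text, text).
line = "1\t口味清淡\t口味很清淡,价格也比较公道"

label, target_text, text = line.strip().split("\t")
example = {
    "label": int(label),         # sentiment label (the row above is a positive example)
    "target_text": target_text,  # aspect + opinion, e.g. "口味清淡"
    "text": text,                # the original review sentence
}
print(example)
```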

## 4. Model Performance

The classification model was trained for 10 epochs, and the checkpoint with the highest evaluation F1 score was selected as the best model. The detailed training settings are listed in the table below:

|Model|Training configuration|Hardware|MD5|
| ------------ | ------------ | ------------ |-----------|
|[cls_model](https://bj.bcebos.com/paddlenlp/models/best_cls.pdparams)|<div style="width: 150pt"> learning_rate: 3e-5, batch_size: 16, max_seq_len: 256, epochs: 10 </div>|<div style="width: 100pt">Tesla V100-32g</div>|3de6ddf581e665d9b1d035c29b49778a|

We evaluated the best checkpoint from training on the validation set `dev_set` and the test set `test_set`; the results are shown below:

|Model|Dataset|Precision|Recall|F1|
| ------------ | ------------ | ------------ |-----------|------------ |
|SKEP-Large|dev_set|0.98758|0.99251|0.99004|
|SKEP-Large|test_set|0.98497|0.99139|0.98817|

**Note**: The numbers above come from training and testing on the full dataset, not the demo dataset.

## 5. Model Training

Train the classification model with:
```shell
sh run_train.sh
```

## 6. Model Testing

Test the classification model with:
```shell
sh run_evaluate.sh
```
59 changes: 59 additions & 0 deletions applications/sentiment_analysis/classification/data.py
@@ -0,0 +1,59 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def load_dict(dict_path):
    with open(dict_path, "r", encoding="utf-8") as f:
        words = [word.strip() for word in f.readlines()]
    word2id = dict(zip(words, range(len(words))))
    id2word = dict((v, k) for k, v in word2id.items())

    return word2id, id2word


def read(data_path):
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f.readlines():
            items = line.strip().split("\t")
            assert len(items) == 3
            example = {
                "label": int(items[0]),
                "target_text": items[1],
                "text": items[2]
            }

            yield example


def convert_example_to_feature(example,
                               tokenizer,
                               label2id,
                               max_seq_len=512,
                               is_test=False):
    encoded_inputs = tokenizer(
        example["target_text"],
        text_pair=example["text"],
        max_seq_len=max_seq_len,
        return_length=True)

    if not is_test:
        label = example["label"]
        return encoded_inputs["input_ids"], encoded_inputs[
            "token_type_ids"], encoded_inputs["seq_len"], label

    return encoded_inputs["input_ids"], encoded_inputs[
        "token_type_ids"], encoded_inputs["seq_len"]
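The file-reading helpers above can be exercised end to end with a throwaway label dict and dataset file. The sketch below copies the `load_dict` and `read` logic (so the block is self-contained) and runs it on temporary files; the `negative`/`positive` label vocabulary is an assumption for illustration only.

```python
import os
import tempfile

def load_dict(dict_path):
    # Same logic as data.py: one label per line, mapped to its line index.
    with open(dict_path, "r", encoding="utf-8") as f:
        words = [word.strip() for word in f.readlines()]
    word2id = dict(zip(words, range(len(words))))
    id2word = dict((v, k) for k, v in word2id.items())
    return word2id, id2word

def read(data_path):
    # Same logic as data.py: yield one example dict per tab-separated line.
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            items = line.strip().split("\t")
            assert len(items) == 3
            yield {"label": int(items[0]), "target_text": items[1], "text": items[2]}

with tempfile.TemporaryDirectory() as d:
    # Assumed two-label vocabulary, written to a temporary dict file.
    dict_path = os.path.join(d, "label.dict")
    with open(dict_path, "w", encoding="utf-8") as f:
        f.write("negative\npositive\n")
    word2id, id2word = load_dict(dict_path)
    print(word2id)  # {'negative': 0, 'positive': 1}

    # One demo sample in the three-column format.
    data_path = os.path.join(d, "demo.tsv")
    with open(data_path, "w", encoding="utf-8") as f:
        f.write("1\t口味清淡\t口味很清淡,价格也比较公道\n")
    examples = list(read(data_path))
    print(examples[0]["target_text"])  # 口味清淡
```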
82 changes: 82 additions & 0 deletions applications/sentiment_analysis/classification/evaluate.py
@@ -0,0 +1,82 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from tqdm import tqdm
from functools import partial
import paddle
from paddlenlp.data import Pad, Stack, Tuple
from paddlenlp.metrics.glue import AccuracyAndF1
from paddlenlp.datasets import load_dataset
from paddlenlp.transformers import SkepModel, SkepTokenizer
from model import SkepForSequenceClassification
from data import read, load_dict, convert_example_to_feature


def evaluate(model, data_loader, metric):
    model.eval()
    metric.reset()
    for batch_data in tqdm(data_loader):
        input_ids, token_type_ids, _, labels = batch_data
        logits = model(input_ids, token_type_ids=token_type_ids)
        correct = metric.compute(logits, labels)
        metric.update(correct)

    accuracy, precision, recall, f1, _ = metric.accumulate()

    return accuracy, precision, recall, f1


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default=None, help="The path of saved model that you want to load.")
    parser.add_argument('--test_path', type=str, default=None, help="The path of test set.")
    parser.add_argument("--label_path", type=str, default=None, help="The path of label dict.")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size per GPU/CPU for training.")
    parser.add_argument("--max_seq_len", type=int, default=512, help="The maximum total input sequence length after tokenization.")
    args = parser.parse_args()
    # yapf: enable

    # load test data
    model_name = "skep_ernie_1.0_large_ch"
    label2id, id2label = load_dict(args.label_path)
    test_ds = load_dataset(read, data_path=args.test_path, lazy=False)

    tokenizer = SkepTokenizer.from_pretrained(model_name)
    trans_func = partial(convert_example_to_feature, tokenizer=tokenizer, label2id=label2id, max_seq_len=args.max_seq_len)
    test_ds = test_ds.map(trans_func, lazy=False)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
        Stack(dtype="int64"),
        Stack(dtype="int64")
    ): fn(samples)

    test_batch_sampler = paddle.io.BatchSampler(test_ds, batch_size=args.batch_size, shuffle=False)
    test_loader = paddle.io.DataLoader(test_ds, batch_sampler=test_batch_sampler, collate_fn=batchify_fn)

    # load model
    loaded_state_dict = paddle.load(args.model_path)
    skep = SkepModel.from_pretrained(model_name)
    model = SkepForSequenceClassification(skep, num_classes=len(label2id))
    model.load_dict(loaded_state_dict)

    metric = AccuracyAndF1()

    # evaluate on test data
    accuracy, precision, recall, f1 = evaluate(model, test_loader, metric)
    print(f'evaluation result: accuracy: {accuracy:.5f}, precision: {precision:.5f}, recall: {recall:.5f}, F1: {f1:.5f}')
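The `batchify_fn` above pads the variable-length id sequences and stacks the fixed-size fields into batch arrays. The plain-Python sketch below mimics what `Pad` and `Stack` from `paddlenlp.data` do on a toy batch (using lists instead of tensors), purely to illustrate the collation step:

```python
# Toy batch of (input_ids, token_type_ids, seq_len, label) tuples, as yielded
# by convert_example_to_feature.
samples = [
    ([1, 5, 7], [0, 0, 0], 3, 1),
    ([1, 9], [0, 0], 2, 0),
]

pad_val = 0
max_len = max(len(ids) for ids, _, _, _ in samples)

# Pad(axis=0, pad_val=...) equivalent: right-pad every sequence to the batch max.
input_ids = [ids + [pad_val] * (max_len - len(ids)) for ids, _, _, _ in samples]
token_type_ids = [tt + [pad_val] * (max_len - len(tt)) for _, tt, _, _ in samples]

# Stack equivalent: collect the scalar fields into per-batch lists.
seq_lens = [sl for _, _, sl, _ in samples]
labels = [lb for _, _, _, lb in samples]

print(input_ids, token_type_ids, seq_lens, labels)
```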
42 changes: 42 additions & 0 deletions applications/sentiment_analysis/classification/model.py
@@ -0,0 +1,42 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle


class SkepForSequenceClassification(paddle.nn.Layer):
    def __init__(self, skep, num_classes=2, dropout=None):
        super(SkepForSequenceClassification, self).__init__()
        self.num_classes = num_classes
        self.skep = skep
        self.dropout = paddle.nn.Dropout(
            dropout
            if dropout is not None else self.skep.config["hidden_dropout_prob"])
        self.classifier = paddle.nn.Linear(self.skep.config["hidden_size"],
                                           num_classes)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        _, pooled_output = self.skep(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask)

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits