Skip to content

MeanZhang/TextClassification

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

2 Commits
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

基于 TextCNN 模型的中文文本分类

1 项目目标

该项目的基本目标是用计算机程序对文本集按照一定的分类体系或标准进行自动分类标记。该程序根据一个已经被标注的训练文本集合,找到文本特征和文本类别之间的关系模型,然后利用这种学习得到的关系模型对新的文本进行类别判断 。

具体来说,在本项目中,使用 Python 编程语言,利用 Pytorch、Torchtext 等工具,训练出一个可对中文文本进行分类的 TextCNN 模型,将 THUCNews 中文文本数据集中的部分文本分为 10 个类别。

2 关键技术

2.1 文本预处理

文本预处理过程是提取文本中的关键词来表示文本的过程。中文文本预处理主要包括文本分词和去停用词两个阶段。

2.1.1 中文分词

对于英文文本,可以直接使用空格来分割单词。但是中文文本由汉字组成,除标点符号外没有其他明显的分割标志。而很多研究和实践表明,对于特征粒度,词粒度远好于字粒度。因此这里使用了 Python jieba 库进行中文分词。本项目重点在于文本分类,分词方法不做过多研究。

2.1.2 文本清洗

大多数情况下,中文文本中存在许多无实际意义的标点符号、介词、连词、语气词以及一些不易理解的词语,这些词语对于文本特征没有任何贡献。在处理文本时,一般需要去除这些词汇和符号。

此外,数据集中的文本可能夹杂有英文,而英文又有大小写之分,所以在处理时需要将英文统一转化为小写。

这里使用了一份停用词表来去除停用词,使用 Torchtext 提供的方法进行小写转换。

2.2 特征提取

经过文本预处理后的数据仍然不能无法直接输入模型,需要将文本转换为可以表示文本特征的向量。

特征提取的方法有很多,如 One-Hot 表示、词袋模型、TF-IDF 模型等。

这里使用了 Tortext 内置的方法 build_vocab,可以根据数据集内的词的频率等信息生成一个词典。可以根据这个词典构造出文本的向量表示。

2.3 TextCNN 模型

TextCNN1是一种用于文本分析的卷积神经网络(CNN)。相比于 RNN 和 LSTM 等传统模型,CNN 能更加高效的提取局部信息。在文本分类任务中,可以利用 CNN 来提取句子中类似 n-gram 的关键信息。TextCNN 利用多个不同 size 的卷积核提取句子中的关键信息(类似于多窗口的 N-Gram 模型),使用 Max-Pooling 选择出最具影响力的高维分类特征,再使用带有 Dropout 的全连接层提取文本深度特征,最后接 softmax 进行分类。

TextCNN 模型的具体使用在后文中会详细介绍。

3 核心思想和算法描述

3.1 核心思想

文本分类的核心思想是从文本中抽取出能够体现文本特点的关键特征,抓取特征到类别之间的映射。

从文本中提取特征在该项目中就是将文本通过词典转换为向量。

而特征到类别之间的映射就是训练后的模型,向模型输入文本转换后得到的向量后,输出 n 维向量(n 为分类数)。

3.2 算法描述

该项目训练流程如下:

st=>start: 开始
op1=>operation: 构造文本和标签Feild
io1=>inputoutput: 读取数据集,构造Dataset
op2=>operation: 用训练集和测试集建立词典
op3=>operation: 构造迭代器
op4=>operation: 将迭代器输入模型,进行训练
io2=>inputoutput: 训练完成,保存模型
op5=>operation: 进行测试
e=>end: 结束框
st->op1->io1->op2->op3->op4->io2->op5->e

其中 FeildDatasetTorchtext.data 中的类型,用于数据处理。

3.2.1 数据预处理

该模块的作用为对数据集的原始数据进行处理,转换为可以用于训练的迭代器。流程如下:

首先构造文本和标签的 Field,用于定义数据的格式规范和处理方法,其中对文本 Field 定义分词方法使用 jieba.lcut(),并且使用停用词。接着加载数据集中的数据,进行处理,构造 Dataset(即分词后的数据),并构建词典(用于生成向量)。然后使用 Dataset 分别构造训练集、验证集和测试集的迭代器 BucketIterator,同时指定 batch_sizes。这里选择了使用 BucketIterator,可以将长度相似的文本划分在同一 batch。

3.2.2 训练模块

该模块包括与训练相关的训练、评估、测试和预测 4 个函数。

训练的基本流程为:将 batch 输入模型,计算交叉熵,反向传递,使用 Adam 优化器优化梯度。其中学习率等参数在主程序中设置。每经过特定数量的 batch 后,使用验证集计算 loss,如果长时间 loss 未得到优化,就停止训练。

评估函数与训练类似,只是少了梯度的计算和优化,统计 loss,计算平均值。

测试函数是用于对训练完成后的模型进行测试的,与评估函数类似,但是加上了一些更加详细的统计方法,使用了 sklearn.metrics 模块中的相关方法,可以计算各类别的精确率、召回率、F1 score 以及混淆矩阵。

预测函数用于对单句文本进行预测,输入字符串后,加载保存的模型和词典,将字符串转为向量后输入模型,得到分类结果,处理后输出。该函数在主程序中使用。

3.2.3 主程序

主程序包括训练和预测两个功能,可以通过命令行指定参数和执行的操作。

4 主要模型

该项目主要模型为 TextCNN 模型,结构如下图所示:

TextCNN1

详细结构2如下:

TextCNN2

TextCNN 输入一个文本转换后得到的矩阵(向量),然后使用卷积核进行卷积操作,然后经过激活函数得到特征图,经过池化和全连接得到输出。

4.1 嵌入层(embedding layer)

嵌入层,这一层的主要作用是将输入的文本编码转化为词向量。对于数据集里的所有词,因为每个词都可以表征成一个向量,因此我们可以得到一个嵌入矩阵 $M$$M$ 中的每一行都是一个词向量可以使用与训练的词向量,也可以使用随机初始化的词向量。这里使用了随机初始化的词向量。

4.2 卷积层

这一层主要是通过卷积,提取不同的 n-gram 特征。输入的语句或者文本,通过嵌入层后,会转变成一个二维矩阵。这里选择了多个卷积核,其大小和数量在主程序中设置,大小默认为 $(2,3,4)$ ,数量默认为 128。

4.3 池化层

不同尺寸的卷积核得到的特征图大小也是不一样的,因此我们对每个特征图进行池化,使它们的维度相同。这里使用了 max pooling 作为池化函数,减少模型的参数,又保证了在不定长的卷积层的输出上获得一个定长的全连接层的输入。

4.4 全连接层

全连接层的作用就是分类器,TextCNN 模型使用了只有一层隐藏层的全连接网络。

为了防止过拟合,在全连接层前加入了 Dropout 随机失活,默认为 0.5。

5 模型评估及分析

该项目使用 THUCNews 中文文本数据集3中的部分数据作为数据集。选择了 10 个分类:体育、娱乐、家居、房产、教育、时尚、时政、游戏、科技、财经。

其中训练集每个分类各 1000 条文本,验证集每个分类各 100 条文本,训练集每个分类每个分类各 200 条文本。

5.1 测试效果

由于模型的嵌入层的词向量是随机初始化的,所以每次训练的结果有差别,但差异不大,这里选择了最后一次训练的模型进行测试。

使用如下指标作为评测标准:

准确率 $P_i=\frac{l_i}{m_i}$ ,召回率 $R_i=\frac{l_i}{n_i}$ ,F1-score $F_i=\frac{P_i×R_i×2}{P_i+R_i}$ ,宏平均准确率 $MacroP=\frac{1}{n}\sum_{i=1}^nP_i$ ,宏平均召回率 $MacroR=\frac{1}{n}\sum_{i=1}^nR_i$ ,宏平均 F1-score $MacroF=\frac{1}{n}\sum_{i=1}^nF_i$

其中 $l_i$ 为第 i 类分类正确的文本数, $m_i$ 为分类系统实际分类为 i 的文本数, $n_i$ 为专家分类为 i 的文本数。

结果如下:

********************TEST********************
Loss: 0.000828 Acc: 0.9695
[CLASSIFICATION REPORT]
              precision    recall  f1-score   support

         体育       1.00      0.97      0.98       200
         娱乐       0.98      0.94      0.96       200
         家居       0.91      0.96      0.94       200
         房产       0.97      0.99      0.98       200
         教育       0.95      0.96      0.96       200
         时尚       0.95      0.99      0.97       200
         时政       0.96      0.94      0.95       200
         游戏       0.97      0.97      0.97       200
         科技       1.00      0.98      0.99       200
         财经       0.99      0.97      0.98       200

    accuracy                           0.97      2000
   macro avg       0.97      0.97      0.97      2000
weighted avg       0.97      0.97      0.97      2000

[CONFUSION MATRIX]
[[194   0   1   0   1   0   4   0   0   0]
 [  0 188   7   0   1   1   2   1   0   0]
 [  0   1 192   1   2   4   0   0   0   0]
 [  0   0   1 199   0   0   0   0   0   0]
 [  0   0   1   1 192   3   1   2   0   0]
 [  0   0   0   0   1 199   0   0   0   0]
 [  0   0   5   0   5   0 189   0   0   1]
 [  0   2   2   0   0   1   0 195   0   0]
 [  0   0   1   0   0   1   0   2 196   0]
 [  0   0   0   4   0   0   1   0   0 195]]

test took 3.28s

可以看到,该模型的平均准确率为 96.95%,宏平均准确率、召回率、F1-score 均为 97%,各类别的准确率均高于 91%,召回率均高于 94%,F1-score 均高于 94%。可见模型效果较好。

5.2 预测效果

在网络上查找部分 2020 年 12 月发表的新闻作为输入,系统预测结果如下,其中第 2、4 行为系统输出:

  • 体育

    python .\main.py --predict
    plase enter a sentence in classes ['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经']
    广厦队之所以会去引进尼克-杨,原因是他们本赛季签下的小外援、前NBA球员威尔森-钱德勒打了3场比赛后就拒绝登场。据悉,他不愿意在封闭的环境里训练和比赛,也因此早早就拒绝登场。本赛季,钱德勒为广厦队出场3场,场均得到15.3分7.7个篮板。在他拒绝登场后,广厦队变得非常被动。
    [CLASS]体育
    
  • 娱乐

    python .\main.py --predict
    plase enter a sentence in classes ['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经']
    12月28日,以中国百强电影满意度调查为核心的奖项华鼎奖又来了,演员倪大红,导演康洪雷、姚晓峰等人作为评委出席,有陈宝国、佟瑞欣、张涵予 、郭宝昌、于冬、尉迟辅航等影视大佬到场,不出意外应该是今年最后一场演员大聚会了,一起来围观一下吧!
    [CLASS]娱乐
    
  • 家居

    python .\main.py --predict
    plase enter a sentence in classes ['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经']
    空间设计是一门很复杂的学问,大家不断探索寻找如何通过精心设计与装饰,让有限空间获得完美效果的方法。想要拥有高雅而又不失创意的家居环境,在装饰色彩、小物件的挑选上都需要下一番功夫。
    [CLASS]家居
    
  • 房产

    python .\main.py --predict
    plase enter a sentence in classes ['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经']
    仲量联行12月24日在广州举行“2020年广州房地产市场回顾及2021年展望”发布,仲量联行华南区董事总经理吴仲豪表示,优质零售物业市场方面,广州消费市场持续稳步复苏的态势,全市整体租赁需求稳步回升,2020全年净吸纳量超过46万平方米,位于一线城市之首。
    [CLASS]房产
    
  • 教育

    python .\main.py --predict
    plase enter a sentence in classes ['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经']
    记者调查发现,“考研神校”不再是对某所学校的评价,居于普通院校行列、考研升学率高、学生勤奋备考,具备类似特征的高校越来越常见。软科《2019本科毕业生深造率排名》,排名对象为中国1200多所本科层次的高校,榜单展示了本科毕业生总深造率前两百强的学校。山东科技大学、沈阳农业大学、滨州学院、上海科技大学等高校均在其列。换句话说,与一直在增加的考研大军相关联,“考研神校”也将越来越常见。(光明日报12月28日)
    [CLASS]教育
    

其他类别不再一一测试。

可以看到,预测精确度较高,符合上一节测试结果。

5.3 测试结果分析

通过上面的测试和预测,可以看出该模型的效果较好。但是在单句预测时,有时会出现预测错误的情况。

对于这种情况,有几种可能的原因:

  1. 数据集过时。由于该数据集是 2005~2011 年间的新闻文本,而新闻具有时效性。预测使用的新闻都是 2020 年的,其中有许多词语是新出现的,所以导致文本转换的向量与实际意义差别较大。
  2. 预测文本较短。预测选择的文本大多少于 300 字,而数据集中的文本均为整篇新闻,最长可达 6000 字。所以短文本提取的特征可能较少,导致分类错误。
  3. 训练过程可能存在过拟合问题。

针对这几个原因,可以通过更新数据集、使用长文本预测、调整模型结构和参数等方法优化。

6 项目结构和用法

项目结构如下:

.
├── data                  用于保存数据的文件夹
│   ├── classes.csv       类别列表
│   ├── saved_model.pt    保存的训练后的模型
│   ├── stop_words.txt    停用词表
│   ├── test.tsv          测试集
│   ├── train.tsv         训练集
│   ├── val.tsv           验证集
│   └── vocab.pkl         使用训练集和验证集构建的词典
├── main.py               主程序代码,包括训练和预测功能
├── module.py             TextCNN模型代码
├── preprocess.py         数据预处理代码
├── report.txt            测试报告
├── requirements.txt      项目依赖及版本
├── test.py               测试代码
└── train.py              训练代码

使用命令行运行程序,命令后可以加上参数,参数如下:

usage: main.py [-h] [--data_path DATA_PATH] [--device_num DEVICE_NUM]
               [--filter_sizes FILTER_SIZES] [--lr LR] [--batch_size BATCH_SIZE]      
               [--epochs EPOCHS] [--filter_num FILTER_NUM]
               [--embedding_dim EMBEDDING_DIM] [--dropout DROPOUT]
               [--show_steps SHOW_STEPS] [--stop_improvements STOP_IMPROVEMENTS]      
               [--predict]

optional arguments:
  -h, --help            show this help message and exit
  --data_path DATA_PATH
                        数据保存路径(默认'data')
  --device_num DEVICE_NUM
                        使用设备编号(默认0)
  --filter_sizes FILTER_SIZES
                        filter大小,默认(3,4,5)
  --lr LR               学习率(默认0.001)
  --batch_size BATCH_SIZE
                        batch大小(默认128)
  --epochs EPOCHS       epoch数(默认20)
  --filter_num FILTER_NUM
                        fliter数量(默认128)
  --embedding_dim EMBEDDING_DIM
                        embedding维度(默认128)
  --dropout DROPOUT     dropout(默认0.5)
  --show_steps SHOW_STEPS
                        每多少batch显示信息(默认1)
  --stop_improvements STOP_IMPROVEMENTS
                        多少batch内loss无提高时停止(默认300)
  --predict             预测模式

其中 --predict 表示预测模式,使用该参数不需要传值,使用该参数后进入预测模式,然后输入文本即可输出预测的分类。

test.py 用来预测,可以指定数据集对其进行测试,参数与主程序类似,可使用 python test.py -h 命令查看用法。

参考文献

Footnotes

  1. Kim Y. Convolutional neural networks for sentence classification[J]. arXiv preprint arXiv:1408.5882, 2014.

  2. Zhang Y , Wallace B . A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification[J]. Computer Science, 2015.

  3. 孙茂松,李景阳,郭志芃,赵宇,郑亚斌,司宪策,刘知远. THUCTC:一个高效的中文文本分类工具包. 2016.

About

基于TextCNN和TorchText的中文文本分类

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages