Skip to content
Permalink
master
Switch branches/tags
Go to file
Latest commit 14c0b2c Dec 17, 2021 History
2 contributors

Users who have contributed to this file

@airaria @ymcui

English | 中文说明



GitHub PyPI GitHub release

TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包, 融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架, 用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。

可以通过ACL AnthologyarXiv pre-print查看我们的论文。

TextBrewer完整文档

新闻

Dec 17, 2021

Oct 24, 2021

Jul 8, 2021

  • 新增Transformers 4示例
    • 目前已有示例基于较早版本的Transformers,使用习惯与当前的Transformers不同。为了减少使用中的困惑与bugs,我们添加了基于Transformers 4的notebook示例,更易学习与使用。
    • 新示例位于examples/notebook_examples。详情参见蒸馏任务示例

Mar 1, 2021

  • BERT-EMD示例与自定义distiller

    • MNLI示例中增加了BERT-EMD算法的实现。BERT-EMD通过优化中间层之间的Earth Mvoer's Distance以自适应地调整教师与学生之间中间层匹配,而无需人工指定。
    • BERT-EMD以自定义distiller的方式(EMDDistiller)实现,可作为自定义distiller的参考。
  • MNLI示例更新

    • 更新了MNLI任务上的蒸馏示例,新代码不再依赖pytorch_pretrained_bert而使用transofrmers。
点击查看往期新闻

Nov 11, 2020

  • 版本更新至0.2.1:

    • 灵活性提升:支持为教师模型和学生模型输入各自独立的batch,不再要求教师模型和学生模型的输入相同。可用于词表不同的模型之间(例如从RoBERTa到BERT)的蒸馏。

    • 蒸馏加速:支持用户自定义传入教师模型的输出缓存,避免教师模型的重复前向计算,加速蒸馏过程。

      以上特性的详细说明可参见 Feed Different batches to Student and Teacher, Feed Cached Values

    • 增加了MultiTaskDistiller对中间层匹配损失的支持。

    • Tensorboard中记录更详细的损失函数(KD loss, hard label loss, matching losses...)。

    更新细节参见 releases

Aug 27, 2020

哈工大讯飞联合实验室在通用自然语言理解评测GLUE中荣登榜首,查看GLUE榜单新闻

Aug 24, 2020

  • 版本更新至0.2.0.1:
    • 修复了MultiTaskDistiller以及训练循环中的若干bug。

Jul 29, 2020

  • 版本更新至0.2.0:
    • 增加对分布式数据并行训练的支持:可通过在TrainingConfig中传入相应的local_rank以启用。详细设置参见TraningConfig的说明。
  • 增加了分布式数据并行训练的使用示例:中文命名实体识别任务上的ELECTRA-base模型的蒸馏,见examples/msra_ner_example

Jul 14, 2020

  • 版本更新至0.1.10:
    • 支持apex混合精度训练功能:可通过在TrainingConfig中设置fp16=True启用。详细设置参见TraningConfig的说明。
    • TrainingConfig中增加了data_parallel选项,使得数据并行与混合精度训练可同时启用。

Apr 26, 2020

  • 增加了中文NER任务(MSRA NER)上的实验结果。
  • 英文数据集上增加了蒸馏到T12-nano的实验结果。T12-nano的的结构与ELectra-small相似。
  • 更新了CoNLL-2003、CMRC 2018 和 DRCD 上的部分实验结果。

Apr 22, 2020

  • 版本更新至 0.1.9,增加了为蒸馏过程提速的cache功能,修复了若干bug。细节参见 releases
  • 增加了中文任务上从Electra-base蒸馏到Electra-small的实验结果。
  • TextBrewer被ACL 2020录用为demo paper,欢迎在您的工作中使用我们新的引用

Mar 17, 2020

Mar 11, 2020

  • 版本更新至 0.1.8(改进了TrainingConfig和distiller的train方法),细节参见 releases

Mar 2, 2020

  • 当前版本: 0.1.7, 初始版本。

目录

章节 内容
简介 TextBrewer简介
安装 安装方法介绍
工作流程 TextBrewer整体工作流程
快速开始 举例展示TextBrewer用法:BERT-base蒸馏至3层BERT
蒸馏效果 中文、英文典型数据集上的蒸馏效果展示
核心概念 TextBrewer中的核心概念介绍
FAQ 常见问题解答
引用 TextBrewer参考引用
已知问题 尚未解决的问题
关注我们 -

简介

TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。

主要特点:

  • 模型无关:适用于多种模型结构(主要面向Transfomer结构)
  • 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块
  • 非侵入式:无需对教师与学生模型本身结构进行修改
  • 支持典型的NLP任务:文本分类、阅读理解、序列标注等

TextBrewer目前支持的知识蒸馏技术有:

  • 软标签与硬标签混合训练
  • 动态损失权重调整与蒸馏温度调整
  • 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, ...
  • 任意构建中间层特征匹配方案
  • 多教师知识蒸馏
  • ...

TextBrewer的主要功能与模块分为3块:

  1. Distillers:进行蒸馏的核心部件,不同的distiller提供不同的蒸馏模式。目前包含GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller等
  2. Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数
  3. Utilities:模型参数分析显示等辅助工具

用户需要准备:

  1. 已训练好的教师模型, 待蒸馏的学生模型
  2. 训练数据与必要的实验配置, 即可开始蒸馏

在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果

详细的API可参见 完整文档

TextBrewer结构

安装

安装要求

  • Python >= 3.6
  • PyTorch >= 1.1.0
  • TensorboardX or Tensorboard
  • NumPy
  • tqdm
  • Transformers >= 2.0 (可选, Transformer相关示例需要用到)
  • Apex == 0.1.0 (可选,用于混合精度训练)

安装方式

  • 从PyPI自动下载安装包安装:
pip install textbrewer
  • 从源码文件夹安装:
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer

工作流程

  • Stage 1 : 蒸馏之前的准备工作:

    1. 训练教师模型
    2. 定义与初始化学生模型(随机初始化,或载入预训练权重)
    3. 构造蒸馏用数据集的dataloader,训练学生模型用的optimizer和learning rate scheduler
  • Stage 2 : 使用TextBrewer蒸馏:

    1. 构造训练配置(TrainingConfig)和蒸馏配置(DistillationConfig),初始化distiller
    2. 定义adaptorcallback ,分别用于适配模型输入输出和训练过程中的回调
    3. 调用distillertrain方法开始蒸馏

快速开始

以蒸馏BERT-base到3层BERT为例展示TextBrewer用法。

在开始蒸馏之前准备:

  • 训练好的教师模型teacher_model (BERT-base),待训练学生模型student_model (3-layer BERT)
  • 数据集dataloader,优化器optimizer,学习率调节器类或者构造函数scheduler_class 和构造用的参数字典 scheduler_args

使用TextBrewer蒸馏:

import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig

# 展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)

print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)

# 定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):
    # model输出的第二、三个元素分别是logits和hidden states
    return {'logits': model_outputs[1], 'hidden': model_outputs[2]}

# 蒸馏与训练配置
# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层
distill_config = DistillationConfig(
    intermediate_matches=[    
     {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},
     {'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()

#初始化distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config = distill_config,
    model_T = teacher_model, model_S = student_model, 
    adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)

# 开始蒸馏
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)

蒸馏任务示例

蒸馏效果

我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。

模型

我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。

英文模型

Model #Layers Hidden size Feed-forward size #Params Relative size
BERT-base-cased (教师) 12 768 3072 108M 100%
T6 (学生) 6 768 3072 65M 60%
T3 (学生) 3 768 3072 44M 41%
T3-small (学生) 3 384 1536 17M 16%
T4-Tiny (学生) 4 312 1200 14M 13%
T12-nano (学生) 12 256 1024 17M 16%
BiGRU (学生) - 768 - 31M 29%

中文模型

Model #Layers Hidden size Feed-forward size #Params Relative size
RoBERTa-wwm-ext (教师) 12 768 3072 102M 100%
Electra-base (教师) 12 768 3072 102M 100%
T3 (学生) 3 768 3072 38M 37%
T3-small (学生) 3 384 1536 14M 14%
T4-Tiny (学生) 4 312 1200 11M 11%
Electra-small (学生) 12 256 1024 12M 12%

蒸馏配置

distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches)
# 其他参数为默认值

不同的模型用的matches我们采用了以下配置:

Model matches
BiGRU None
T6 L6_hidden_mse + L6_hidden_smmd
T3 L3_hidden_mse + L3_hidden_smmd
T3-small L3n_hidden_mse + L3_hidden_smmd
T4-Tiny L4t_hidden_mse + L4_hidden_smmd
T12-nano small_hidden_mse + small_hidden_smmd
Electra-small small_hidden_mse + small_hidden_smmd

各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。

训练配置

蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。

英文实验结果

在英文实验中,我们使用了如下三个典型数据集。

Dataset Task type Metrics #Train #Dev Note
MNLI 文本分类 m/mm Acc 393K 20K 句对三分类任务
SQuAD 1.1 阅读理解 EM/F1 88K 11K 篇章片段抽取型阅读理解
CoNLL-2003 序列标注 F1 23K 6K 命名实体识别任务

我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。

Public results:

Model (public) MNLI SQuAD CoNLL-2003
DistilBERT (T6) 81.6 / 81.1 78.1 / 86.2 -
BERT6-PKD (T6) 81.5 / 81.0 77.1 / 85.3 -
BERT-of-Theseus (T6) 82.4/ 82.1 - -
BERT3-PKD (T3) 76.7 / 76.3 - -
TinyBERT (T4-tiny) 82.8 / 82.9 72.7 / 82.1 -

Our results:

Model (ours) MNLI SQuAD CoNLL-2003
BERT-base-cased (教师) 83.7 / 84.0 81.5 / 88.6 91.1
BiGRU - - 85.3
T6 83.5 / 84.0 80.8 / 88.1 90.7
T3 81.8 / 82.7 76.4 / 84.9 87.5
T3-small 81.3 / 81.7 72.3 / 81.4 78.6
T4-tiny 82.0 / 82.6 75.2 / 84.0 89.1
T12-nano 83.2 / 83.9 79.0 / 86.6 89.6

说明:

  1. 公开模型的名称后括号内是其等价的模型结构
  2. 蒸馏到T4-tiny的实验中,SQuAD任务上使用了NewsQA作为增强数据;CoNLL-2003上使用了HotpotQA的篇章作为增强数据
  3. 蒸馏到T12-nano的实验中,CoNLL-2003上使用了HotpotQA的篇章作为增强数据

中文实验结果

在中文实验中,我们使用了如下典型数据集。

Dataset Task type Metrics #Train #Dev Note
XNLI 文本分类 Acc 393K 2.5K MNLI的中文翻译版本,3分类任务
LCQMC 文本分类 Acc 239K 8.8K 句对二分类任务,判断两个句子的语义是否相同
CMRC 2018 阅读理解 EM/F1 10K 3.4K 篇章片段抽取型阅读理解
DRCD 阅读理解 EM/F1 27K 3.5K 繁体中文篇章片段抽取型阅读理解
MSRA NER 序列标注 F1 45K 3.4K (测试集) 中文命名实体识别

实验结果如下表所示。

Model XNLI LCQMC CMRC 2018 DRCD
RoBERTa-wwm-ext (教师) 79.9 89.4 68.8 / 86.4 86.5 / 92.5
T3 78.4 89.0 66.4 / 84.2 78.2 / 86.4
T3-small 76.0 88.1 58.0 / 79.3 75.8 / 84.8
T4-tiny 76.2 88.4 61.8 / 81.8 77.3 / 86.1
Model XNLI LCQMC CMRC 2018 DRCD MSRA NER
Electra-base (教师) 77.8 89.8 65.6 / 84.7 86.9 / 92.3 95.14
Electra-small 77.7 89.3 66.5 / 84.9 85.5 / 91.3 93.48

说明:

  1. 以RoBERTa-wwm-ext为教师模型蒸馏CMRC 2018和DRCD时,不采用学习率衰减
  2. CMRC 2018和DRCD两个任务上蒸馏时他们互作为增强数据
  3. Electra-base的教师模型训练设置参考自Chinese-ELECTRA
  4. Electra-small学生模型采用预训练权重初始化

核心概念

Configurations

  • TrainingConfigDistillationConfig:训练和蒸馏相关的配置。

Distillers

Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:

  • BasicDistiller: 提供单模型单任务蒸馏方式。可用作测试或简单实验。
  • GeneralDistiller (常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用
  • MultiTeacherDistiller: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配
  • MultiTaskDistiller:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。
  • BasicTrainer:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型

用户定义函数

蒸馏实验中,有两个组件需要由用户提供,分别是callbackadaptor :

Callback

回调函数。在每个checkpoint,保存模型后会被distiller调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。

Adaptor

将模型的输入和输出转换为指定的格式,向distiller解释模型的输入和输出,以便distiller根据不同的策略进行不同的计算。在每个训练步,batch和模型的输出model_outputs会作为参数传递给adaptoradaptor负责重新组织这些数据,返回一个字典。

更多细节可参见完整文档中的说明。

FAQ

Q: 学生模型该如何初始化?

A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。

Q: 如何设置蒸馏的训练参数以达到一个较好的效果?

A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考

Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?

A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?

A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。

已知问题

  • 尚不支持DataParallel以外的多卡训练策略。
  • 尚不支持多标签分类任务。

引用

如果TextBrewer工具包对你的研究工作有所帮助,请在文献中引用我们的论文

@InProceedings{textbrewer-acl2020-demo,
    title = "{T}ext{B}rewer: {A}n {O}pen-{S}ource {K}nowledge {D}istillation {T}oolkit for {N}atural {L}anguage {P}rocessing",
    author = "Yang, Ziqing and Cui, Yiming and Chen, Zhipeng and Che, Wanxiang and Liu, Ting and Wang, Shijin and Hu, Guoping",
    booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://www.aclweb.org/anthology/2020.acl-demos.2",
    pages = "9--16",
}

关注我们

欢迎关注哈工大讯飞联合实验室官方微信公众号,了解最新的技术动态。