diff --git a/examples/model_compression/PP-MiniLM/README.md b/examples/model_compression/PP-MiniLM/README.md new file mode 100644 index 000000000000..16dd0a48f3b0 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/README.md @@ -0,0 +1,389 @@ + **目录** + +* [PP-MiniLM 中文小模型](#PP-MiniLM中文小模型) + * [导入 PP-MiniLM](#导入PP-MiniLM) + * [在下游任务上使用 PP-MiniLM](#在下游任务上使用PP-MiniLM) + * [数据介绍](#数据介绍) + * [环境依赖](#环境依赖) + * [微调](#微调) + * [运行方式](#运行方式) + * [微调后模型精度](#微调后模型精度) + * [导出微调后模型](#导出微调后模型) + * [裁剪](#裁剪) + * [原理简介](#原理简介) + * [运行方式](#运行方式) + * [裁剪后模型精度](#裁剪后模型精度) + * [导出裁剪后的模型](#导出裁剪后的模型) + * [量化](#量化) + * [原理简介](#原理简介) + * [运行方式](#运行方式) + * [量化后模型精度](#量化后模型精度) + * [预测](#预测) + * [环境要求](#环境要求) + * [运行方式](#运行方式) + * [性能测试](#性能测试) + * [参考文献](#参考文献) + + + +# PP-MiniLM 中文小模型 +[PaddleNLP](https://github.com/PaddlePaddle/PaddleNLP) 联合 [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) 通过模型蒸馏、剪裁、量化等级联模型压缩技术发布中文特色小模型 PP-MiniLM(6L768H) 及压缩方案,保证模型精度的同时模型推理速度达 BERT(12L768H) 的 4.2 倍,参数量相比减少 52%,模型精度在中文语言理解评测基准 CLUE 高 0.32。 + +PP-MiniLM 压缩方案以面向预训练模型的任务无关知识蒸馏(Task-agnostic Distillation)技术、裁剪(Pruning)技术、量化(Quantization)技术为核心,使得 PP-MiniLM **又快**、**又准**、**又小**。 + +1. **推理速度快**: 依托 PaddleSlim 的裁剪、量化技术对 PP-MiniLM 小模型进行压缩、加速, 使得 PP-MiniLM 量化后模型 GPU 推理速度相比 BERT base 加速比高达 4.2; + +2. **精度高**: 我们以 [MiniLMv2](https://arxiv.org/abs/2012.15828) 提出的 Multi-Head Self-Attention Relation Distillation 技术为基础,通过引入样本间关系知识蒸馏做了进一步算法优化,6 层 PP-MiniLM 模型在 CLUE 数据集上比 12 层 `bert-base-chinese` 高 0.32%,比同等规模的 TinyBERT、UER-py RoBERTa 分别高 2.09%、1.91%; + +3. **参数规模小**:依托 Task-agnostic Distillation 技术和 PaddleSlim 裁剪技术,模型参数量相比 BERT 减少 52%。 + +**整体效果** + +| Model | #Params | #FLOPs | Speedup | AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC | CSL | CLUE 平均值 | +| ----------------------- | ------- | ------ | ------- | ----- | ----- | ------- | ----- | ----- | ----- | ----- | ---------- | +| Bertbase | 102.3M | 10.87B | 1.00x | 74.14 | 56.81 | 61.10 | 81.19 | 74.85 | 79.93 | 81.47 | 72.78 | +| TinyBERT6 | 59.7M | 5.44B | 1.66x | 72.59 | 55.70 | 57.64 | 79.57 | 73.97 | 77.63 | 80.00 | 71.01 | +| UER-py RoBERTa L6- H768 | 59.7M | 5.44B | 1.66x | 69.74 | 66.36 | 59.95 | 77.00 | 71.39 | 71.05 | 82.83 | 71.19 | +| RBT6, Chinese | 59.7M | 5.44B | 1.66x | 73.93 | 56.63 | 59.79 | 79.28 | 73.12 | 77.30 | 80.80 | 71.55 | +| ERNIE-Tiny | 90.7M | 4.83B | 1.89x | 70.67 | 55.60 | 59.91 | 75.74 | 71.36 | 67.11 | 76.70 | 68.16 | +| PP-MiniLM 6L-768H | 59.7M | 5.44B | 1.66x | 74.14 | 57.43 | 61.75 | 81.01 | 76.17 | 86.18 | 77.47 | 73.45 | +| PP-MiniLM裁剪后 | 49.1M | 4.08B | 2.00x | 73.91 | 57.44 | 61.64 | 81.10 | 75.59 | 85.86 | 77.97 | 73.36 | +| PP-MiniLM量化后 | 49.2M | 4.08B | 4.15x | 74.00 | 57.37 | 61.33 | 81.09 | 75.56 | 85.85 | 76.53 | 73.10 | + + +**NOTE:** 说明: + +1.上表所有模型的精度测试均是基于下方超参数范围进行的 Grid Search 超参寻优。在每个配置下训练时,每隔 100 个 steps 在验证集上评估一次,取验证集上最佳准确率作为当前超参数配置下的准确率; +- batch sizes: 16, 32, 64; +- learning rates: 3e-5, 5e-5, 1e-4 +2.量化后比量化前模型参数量多了 0.1M 是因为保存了 scale 值。 + +**方案流程** + +

+![方案流程图](pp-minilm.png)

+ +如上流程图所示,完整的中文小模型方案分为:导入 PP-MiniLM 中文预训练小模型、下游任务微调、裁剪、离线量化、预测部署五大步。下面会对这里的每一个步骤进行介绍。除了下游任务微调步骤,其余步骤均可以省略,但我们建议保留下面的每一个步骤。 + +以下是本范例模型的简要目录结构及说明: + +```shell +. +├── general_distill # 任务无关知识蒸馏目录 +│ └── general_distill.py # 任务无关知识蒸馏脚本 +│ └── run.sh # 任务无关知识蒸馏启动脚本 +│ └── README.md # 任务无关知识蒸馏文档 +├── finetuning # 下游任务训练目录 +│ └── run_clue.py # CLUE 上的微调脚本 +│ └── run_clue.sh # CLUE 上的微调启动脚本 +│ └── run_one_search.sh # 单数据集下精调脚本 +│ └── run_all_search.sh # CLUE数据集下精调脚本 +│ └── export_model.sh # 导出 fine-tuned 部署模型脚本 +├── pruning # 裁剪、蒸馏目录 +│ └── prune.py # 裁剪、蒸馏脚本 +│ └── prune.sh # 裁剪、蒸馏启动脚本 +│ └── export_model.py # 导出裁剪训练得到的子模型(动、静态图模型) +├── quantization # 离线量化目录 +│ └── quant_post.py # 离线量化脚本 +│ └── quant.sh # 离线量化启动脚本 +├── inference # 预测目录 +│ └── infer.py # 预测脚本 +│ └── infer_all.sh # 批量预测量化模型启动脚本 +│ └── infer_perf.py # 量化模型性能测试脚本 +│ └── infer_perf.sh # 量化模型性能测试启动脚本 +├── data.py # 数据处理脚本 +├── pp-minilm.png # PP-MiniLM 方案流程图 +└── README.md # 文档,本文件 + +``` + + + +## 导入 PP-MiniLM + +PP-MiniLM 是使用任务无关蒸馏方法,以 `roberta-wwm-ext-large` 做教师模型蒸馏产出的 6 层 ERNIE 模型(即包含 6 层 Transformer Encoder Layer、Hidden Size 为 768 的中文预训练小模型),在 CLUE 上 7 个分类任务上的模型精度超过 BERTbase、TinyBERT6、UER-py RoBERTa L6-H768、RBT6。 + +可以这样导入 PP-MiniLM: + +```python + +from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification + +model = ErnieModel.from_pretrained('ppminilm-6l-768h') +model = ErnieForSequenceClassification.from_pretrained('ppminilm-6l-768h') # 用于分类任务 +``` + +PP-MiniLM 是一个 6 层的预训练模型,使用 `from_pretrained`导入 PP-MiniLM 之后,就可以在自己的数据集上进行 fine-tuning。接下来会介绍如何用下游任务数据在导入的 PP-MiniLM 上进行微调、进一步压缩及推理部署。 + +**NOTE:** 如果对 PP-MiniLM 的训练过程感兴趣,可以查看[任务无关蒸馏文档](general_distill/README.md)了解相关细节。 + + + +## 在下游任务上使用 PP-MiniLM + +PP-MiniLM 预训练小模型在 CLUE 中的 7 个分类数据集的平均精度上比 12 层 `bert-base-chinese` 高 0.32%,比同等规模的 TinyBERT、UER-py RoBERTa 分别高 2.09%、1.91%,因此我们推荐将 PP-MiniLM 运用在中文下游任务上。当然,如果想对已有模型进一步压缩,也可以参考这里的压缩方案,因为压缩方案是通用的。 + +本案例中会以 CLUE 中 7 个分类数据集为例介绍如何在下游任务上使用 PP-MiniLM。首先用 CLUE 中的数据集对预训练小模型 PP-MiniLM 进行微调,然后提供了一套压缩方案,即借助 [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) 进行裁剪和量化,进一步对模型规模进行压缩,最终使用基于 TensorRT 的 [Paddle Inference](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/inference_cn.html) 预测库对量化后的模型进行预测部署。裁剪、量化前,6 层 PP-MiniLM 的推理速度达`bert-base-chinese`的 1.7 倍,在下游任务上压缩完成后,模型推理速度高达`bert-base-chinese`的 4.2 倍。 + + + +### 数据介绍 + +本案例中下游任务使用的数据是 CLUE 中 7 个分类数据集,包括 AFQMC、TNEWS、IFLYTEK、OCNLI、CMNLI、CSL、CLUEWSC2020。在 Linux 环境下,运行 `run_clue.py` 这个 fine-tuning 脚本会将该数据集自动下载到`~/.paddlenlp/datasets/Clue/`目录下。 + + + +### 环境依赖 + +压缩方案依赖 [PaddleSlim](https://github.com/PaddlePaddle/PaddleSlim) 提供的裁剪、量化功能,因此需要安装 paddleslim。PaddleSlim 是个专注于深度学习模型压缩的工具库,提供剪裁、量化、蒸馏、和模型结构搜索等模型压缩策略,帮助用户快速实现模型的小型化。 + +```shell +pip install -U paddleslim -i https://pypi.org/simple +``` + + + +### 微调 + +基于如下超参范围对 PP-MiniLM 在各个任务上进行 Grid Search 超参寻优: + +- batch sizes: 16, 32, 64 +- learning rates: 3e-5, 5e-5, 1e-4 + + + +#### 运行方式 + +```shell +cd finetuning +mkdir ppminilm-6l-768h +sh run_all_search.sh ppminilm-6l-768h +``` + +如果只是在单个数据集上用特定 `batch_size`、`learning_rate` 微调,可以使用如下命令: + +``` +sh run_clue.sh CLUEWSC2020 1e-4 32 50 128 0 ppminilm-6l-768h +``` + +其中每个参数依次表示:CLUE 中的任务名称、学习率、batch size、epoch 数、最大序列长度、gpu id、模型名称(模型保存目录)。 + + + +#### 微调后模型精度 + +经过超参寻优后,我们可以得到在 CLUE 每个任务上验证集上有最高准确率的模型,CLUE 上各个任务上的最高准确率如下表: + +| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC | CSL | CLUE 平均值 | +| ----- | ----- | ------- | ----- | ----- | ----- | ----- | ----------- | +| 74.14 | 57.43 | 61.75 | 81.01 | 76.17 | 86.18 | 77.47 | 
73.45 | + + +超参寻优完成后,保存下每个数据集下有最高准确率的模型,以及其对应的超参数,因裁剪、量化等后续步骤需要用到最好的模型和超参数。 + + + +#### 导出微调后模型 + +如果模型经过了超参寻优,在这一步我们可以在每个任务上选择表现最好的模型进行导出。 + +假设待导出的模型的地址为 `ppminilm-6l-768h/models/CLUEWSC2020/1e-4_32`,可以运行下方命令将动态图模型导出为可用于部署的静态图模型: + +```shell +python export_model.py --model_type ernie --model_path ppminilm-6l-768h/models/CLUEWSC2020/1e-4_32 --output_path fine_tuned_infer_model/float +cd .. +``` + + + +### 裁剪 + +这一步主要使用 PaddleSlim 对下游任务上的模型宽度进行裁剪,进一步压缩模型的大小。 + +该过程会以上一步的模型(即 fine-tuning 后得到的最好模型)当作教师模型,蒸馏宽度为 3/4 的学生模型。经过我们的实验,在 6L768H 条件下,模型宽度压缩为原来的 3/4,精度几乎无损(-0.09)。 + + + +#### 原理简介 + +本方案采取的裁剪方法参考了 [DynaBERT-Dynamic BERT with Adaptive Width and Depth](https://arxiv.org/pdf/2004.04037) 中的策略。首先对预训练模型和 Head 进行重要性排序,保证更重要的 Head 不容易被裁掉,然后用原模型作为蒸馏过程中的教师模型,宽度更小的(本案例是 3/4 宽度)模型作为学生模型,蒸馏得到的学生模型就是我们裁剪得到的模型。 + + + +#### 运行方式 + +假设需要对上一步 fine-tuned 模型 `../finetuning/ppminilm-6l-768h/models/CLUEWSC2020/1e-4_32` 进行裁剪,其中 `learning_rate`、`batch_size` 可以继续使用 fine-tuning 时的参数,这里执行的是宽度 `0.75` 的裁剪,可以使用如下命令运行: + +```shell +cd pruning +export FT_MODELS=../finetuning/ppminilm-6l-768h/models/CLUEWSC2020/1e-4_32 + +sh prune.sh CLUEWSC2020 5e-5 16 50 128 4 ${FT_MODELS} 0.75 +``` +其中每个参数依次表示:CLUE 中的任务名称、学习率、batch size、epoch 数、最大序列长度、gpu id、学生模型的地址、裁剪后宽度比例列表。执行完成后,模型保存的路径位于 `pruned_models/CLUEWSC2020/0.75/best_model/`。 + + + +#### 裁剪后模型精度 + +经过裁剪后,CLUE 上各个任务上的精度如下表所示。相比起裁剪前,CLUE 数据集上平均值下降 0.09。模型的参数量由 59.7M 下降到 49.1M。 + +| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC | CSL | CLUE 平均值 | +| ----- | ----- | ------- | ----- | ----- | ----- | ----- | ----------- | +| 73.91 | 57.44 | 61.64 | 81.10 | 75.59 | 85.86 | 77.97 | 73.36 | + + + + +#### 导出裁剪后的模型 + +这一步可以同时导出经过裁剪后特定宽度下模型的动、静态图的模型和参数等文件。 + +以 CLUEWSC2020 数据集为例,导出模型: + +```shell + +export MODEL_PATH=pruned_models +export TASK_NAME=CLUEWSC2020 +sh export.sh ${MODEL_PATH} ${TASK_NAME} +``` + +或者可以批量导出 CLUE 各个任务上的模型: + +```shell + +sh export_all.sh +cd .. +``` + +导出后的模型位于 `${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float`。 + + + +### 量化 + +```shell +cd quantization +``` + + + +#### 原理简介 + +这里的量化采用的是静态离线量化方法,即不需要训练,只使用少量校准数据计算量化因子,就可快速得到量化模型。这一步需要有训练好的预测(静态图)模型。因此,需要对前序步骤产出的模型进行导出(参考上方导出模型的运行方式)。 + +量化我们可以借助 PaddleSlim 提供的离线量化 API `paddleslim.quant.quant_post_static` 实现,我们这一步使用了 `mse`、`avg`、`abs_max`、`hist` 四种策略,并使用 4、8 两种校准集数量,对 `matmul/matmul_v2` 算子进行`channel_wise_abs_max` 类型的量化。 + + + +#### 运行方式 + +运行如下脚本可以得到静态离线量化后的模型: + +```shell +export MODEL_DIR=../pruning/pruned_models/ +python quant_post.py --task_name $TASK_NAME --input_dir ${MODEL_DIR}/${TASK_NAME}/0.75/sub_static +``` + +可以批量对所有数据集下的 FP32 模型进行量化: + +```shell +sh quant_all.sh +cd .. 
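+# 量化后的 INT8 模型会保存在 quantization/${task}_quant_models/<量化策略><校准集数量>/int8 路径下,
+# 供下一步 inference/infer.py 加载,具体路径写法可参考 inference/infer_all.sh。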
+``` + + + +#### 量化后模型精度 + +经过量化后,CLUE 上各个任务上的精度如下表,比上一步(裁剪后)平均精度下降了 0.26: + +| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC | CSL | CLUE 平均值 | +| ----- | ----- | ------- | ----- | ----- | ----- | ----- | ----------- | +| 74.00 | 57.37 | 61.33 | 81.09 | 75.56 | 85.85 | 76.53 | 73.10 | + +最后,值得注意的是,PP-MiniLM 是基于 `roberta-wwm-ext-large` 做教师模型蒸馏得到的学生模型,如果你有更好的 24 层中文预训练模型,可以基于[任务无关蒸馏文档](general_distill/README.md)中介绍的蒸馏过程,训练出一个比 PP-MiniLM 精度更高,在下游任务上表现更好的 6 层小模型。 + + + +### 预测 + +预测部署借助 PaddlePaddle 安装包中自带的 [Paddle Inference](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/05_inference_deployment/inference/inference_cn.html) 进行预测。 + + + +#### 环境要求 + +这一步依赖安装有预测库的 PaddlePaddle 2.2.1。可以在 [PaddlePaddle 官网](https://paddleinference.paddlepaddle.org.cn/user_guides/download_lib.html#python) 根据机器环境选择合适的 Python 预测库进行安装。 + +想要得到更明显的加速效果,推荐在 NVIDA Tensor Core GPU(如 T4、A10、A100)上进行测试,本案例基于 T4 测试。若在V系列GPU卡上测试,由于其不支持 Int8 Tensor Core,加速效果将达不到本文档表格中的效果。 + +本案例是在 NVIDIA Tesla T4 单卡上,使用 cuda 11.1、cudnn 8.1、TensorRT 7.2 进行预测。 + + + +#### 运行方式 + +这里使用了动态 shape 功能,因此需要设置获取 shape 的范围。Paddle Inference 提供了相应的接口,即首先通过离线输入数据来统计出所有临时 tensor 的 shape 范围,TRT 子图的 tensor 输入 shape 范围可直接根据上一步 tune 出来的结果来设置,即可完成自动 shape 范围设置。统计完成后,只需设置统计结果路径,即可启用 `tuned_dynamic_shape` 功能。在本案例中,只需要先设置 `--collect_shape` 参数,运行 `infer.py`,然后再取消传入这个参数,再次运行 `infer.py`。例如: + +INT8 预测运行脚本: + +```shell + +cd inference +export task=tnews +export algo=mse +export bs=4 +python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt --collect_shape # 生成shape range info文件 +python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt # load shape range info文件进行预测 +``` +如果想要批量对量化模型进行预测并输出不同量化策略产出模型的精度,可以使用如下的脚本批量预测: + +```shell +sh infer_all.sh +``` + +FP32 预测运行脚本: + +```shell +python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt --collect_shape +python infer.py --task_name ${task} --model_path $MODEL_PATH --use_trt +``` + + + +#### 性能测试 + +测试性能环境同上。本案例测试采用的是 CLUE TNEWS 数据集下量化方法为 `mse`、校准集数量为 4 得到的量化模型,在 TNEWS 的验证集上统计 5 次端到端预测的总耗时(前 20 个 steps 作为 warmup steps 跳过)并求平均。下表后三行分别是微调后的模型、裁剪后的模型、量化后模型的总耗时情况,加速倍数列是较 `bert-base-chinese` 的推理加速倍数。 + +启动性能测试需要对 `infer.py` 脚本传入参数 `--perf`,运行性能测试脚本可以得到 PP-MiniLM、PP-MiniLM 裁剪后、PP-MiniLM 量化后模型预测的耗时: + +```shell + +sh infer_perf.sh +cd .. +``` + +取 5 个非 `--collect_shape` 阶段打印出的时长取平均,可以发现借助 PaddleSlim 裁剪、量化后的模型是原 BERTbase模型推理速度的 4.2 倍,其中裁剪后的模型是 BERTbase推理速度的 2.0 倍。 + +| | 平均耗时(s) | 加速比 | +| ------------------- | ----------- | ------ | +| BERTbase | 21.04 | - | +| PP-MiniLM | 12.64 | 1.66x | +| PP-MiniLM裁剪后 | 10.54 | 2.00x | +| PP-MiniLM量化后 | 5.07 | 4.15x | + + + + +## 参考文献 + +1.Wang W, Bao H, Huang S, Dong L, Wei F. MiniLMv2: Multi-Head Self-Attention Relation Distillation for Compressing Pretrained Transformers[J]. arXiv preprint arXiv:2012.15828v2, 2021. + +2.Hou L, Huang Z, Shang L, Jiang X, Chen X and Liu Q. DynaBERT: Dynamic BERT with Adaptive Width and Depth[J]. arXiv preprint arXiv:2004.04037, 2020. + +3.Cai H, Gan C, Wang T, Zhang Z, and Han S. Once for all: Train one network and specialize it for efficient deployment[J]. arXiv preprint arXiv:1908.09791, 2020. + +4.Wu H, Judd P, Zhang X, Isaev M and Micikevicius P. Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation[J]. arXiv preprint arXiv:2004.09602v1, 2020. 
diff --git a/examples/model_compression/PP-MiniLM/data.py b/examples/model_compression/PP-MiniLM/data.py new file mode 100644 index 000000000000..4c4e9e0f5498 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/data.py @@ -0,0 +1,85 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +from paddle.metric import Metric, Accuracy +from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer +from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer + +MODEL_CLASSES = { + "ernie": (ErnieForSequenceClassification, ErnieTokenizer), + "bert": (BertForSequenceClassification, BertTokenizer) +} + +METRIC_CLASSES = { + "afqmc": Accuracy, + "tnews": Accuracy, + "iflytek": Accuracy, + "ocnli": Accuracy, + "cmnli": Accuracy, + "cluewsc2020": Accuracy, + "csl": Accuracy, +} + + +def convert_example(example, + tokenizer, + label_list, + max_seq_length=512, + is_test=False): + """convert a glue example into necessary features""" + if not is_test: + # `label_list == None` is for regression task + label_dtype = "int64" if label_list else "float32" + # Get the label + label = example['label'] + label = np.array([label], dtype=label_dtype) + # Convert raw text to feature + if 'sentence' in example: + example = tokenizer(example['sentence'], max_seq_len=max_seq_length) + elif 'sentence1' in example: + example = tokenizer( + example['sentence1'], + text_pair=example['sentence2'], + max_seq_len=max_seq_length) + elif 'keyword' in example: # CSL + sentence1 = " ".join(example['keyword']) + example = tokenizer( + sentence1, text_pair=example['abst'], max_seq_len=max_seq_length) + elif 'target' in example: # wsc + text, query, pronoun, query_idx, pronoun_idx = example['text'], example[ + 'target']['span1_text'], example['target']['span2_text'], example[ + 'target']['span1_index'], example['target']['span2_index'] + text_list = list(text) + assert text[pronoun_idx:(pronoun_idx + len(pronoun) + )] == pronoun, "pronoun: {}".format(pronoun) + assert text[query_idx:(query_idx + len(query) + )] == query, "query: {}".format(query) + if pronoun_idx > query_idx: + text_list.insert(query_idx, "_") + text_list.insert(query_idx + len(query) + 1, "_") + text_list.insert(pronoun_idx + 2, "[") + text_list.insert(pronoun_idx + len(pronoun) + 2 + 1, "]") + else: + text_list.insert(pronoun_idx, "[") + text_list.insert(pronoun_idx + len(pronoun) + 1, "]") + text_list.insert(query_idx + 2, "_") + text_list.insert(query_idx + len(query) + 2 + 1, "_") + text = "".join(text_list) + example = tokenizer(text, max_seq_len=max_seq_length) + + if not is_test: + return example['input_ids'], example['token_type_ids'], label + else: + return example['input_ids'], example['token_type_ids'] diff --git a/examples/model_compression/PP-MiniLM/finetuning/export_model.py b/examples/model_compression/PP-MiniLM/finetuning/export_model.py new file mode 100644 index 000000000000..1e1e4fe459a6 --- /dev/null +++ 
b/examples/model_compression/PP-MiniLM/finetuning/export_model.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +import paddle + +from run_clue import MODEL_CLASSES + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_path", + default=None, + type=str, + required=True, + help="Path of the trained model to be exported.", ) + parser.add_argument( + "--output_path", + default=None, + type=str, + required=True, + help="The output file prefix used to save the exported inference model.", + ) + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + # build model and load trained parameters + model = model_class.from_pretrained(args.model_path) + # switch to eval model + model.eval() + # convert to static graph with specific input description + model = paddle.jit.to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None, None], dtype="int64"), # input_ids + paddle.static.InputSpec( + shape=[None, None], dtype="int64") # segment_ids + ]) + # save converted static graph model + paddle.jit.save(model, args.output_path) + # also save tokenizer for inference usage + tokenizer = tokenizer_class.from_pretrained(args.model_path) + tokenizer.save_pretrained(os.path.dirname(args.output_path)) + + +if __name__ == "__main__": + main() diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh b/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh new file mode 100644 index 000000000000..63f2b4002898 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_all_search.sh @@ -0,0 +1,35 @@ +# $1 means GENERAL_DIR + +# The penultimate parameter is the card id, this script can be changed if necessary +bash run_one_search.sh $1 afqmc 0 & +bash run_one_search.sh $1 tnews 1 & +bash run_one_search.sh $1 ifly 2 & +bash run_one_search.sh $1 ocnli 3 & +bash run_one_search.sh $1 csl 4 & +bash run_one_search.sh $1 wsc 5 & + +# Because the CMNLI data set is significantly larger than other data sets, +# It needs to be placed on different cards. 
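+# Below, the 3 learning rates x 3 batch sizes grid for CMNLI is launched in the
+# background, one run per card (the 6th argument of run_clue.sh is the card id).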
+lr=1e-4 +bs=16 +sh run_clue.sh CMNLI $lr $bs 3 128 0 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=32 +sh run_clue.sh CMNLI $lr $bs 3 128 1 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=64 +sh run_clue.sh CMNLI $lr $bs 3 128 2 $1 > $1/cmnli/${lr}_${bs}_3_128.log & + +lr=5e-5 +bs=16 +sh run_clue.sh CMNLI $lr $bs 3 128 3 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=32 +sh run_clue.sh CMNLI $lr $bs 3 128 4 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=64 +sh run_clue.sh CMNLI $lr $bs 3 128 5 $1 > $1/cmnli/${lr}_${bs}_3_128.log & + +lr=3e-5 +bs=16 +sh run_clue.sh CMNLI $lr $bs 3 128 6 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=32 +sh run_clue.sh CMNLI $lr $bs 3 128 5 $1 > $1/cmnli/${lr}_${bs}_3_128.log & +bs=64 +sh run_clue.sh CMNLI $lr $bs 3 128 7 $1 > $1/cmnli/${lr}_${bs}_3_128.log & diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_clue.py b/examples/model_compression/PP-MiniLM/finetuning/run_clue.py new file mode 100644 index 000000000000..de08f57fe0ea --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_clue.py @@ -0,0 +1,387 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +import os +import sys +import random +import time +import math +import distutils.util +from functools import partial + +import numpy as np +import paddle +from paddle.io import DataLoader +import paddle.nn as nn +from paddle.metric import Accuracy + +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad, Dict +from paddlenlp.transformers import BertForSequenceClassification, BertTokenizer, BertModel +from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer +from paddlenlp.transformers import LinearDecayWithWarmup + +sys.path.append("../") +from data import convert_example, METRIC_CLASSES, MODEL_CLASSES + +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default=None, + type=str, + required=True, + help="The name of the task to train selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--output_dir", + default="best_clue_model", + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + 
help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--learning_rate", + default=1e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=100, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion" + ) + parser.add_argument( + "--warmup_proportion", + default=0.1, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--adam_epsilon", + default=1e-6, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--do_train", + type=distutils.util.strtobool, + default=True, + help="Whether do train.") + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument( + "--max_grad_norm", + default=1.0, + type=float, + help="The max value of grad norm.") + args = parser.parse_args() + return args + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. 
By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +@paddle.no_grad() +def evaluate(model, loss_fct, metric, data_loader): + model.eval() + metric.reset() + for batch in data_loader: + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids) + loss = loss_fct(logits, labels) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + print("eval loss: %f, acc: %s, " % (loss.numpy(), res), end='') + model.train() + return res + + +def do_eval(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + metric_class = METRIC_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + dev_ds = load_dataset('clue', args.task_name, splits='dev') + + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=dev_ds.label_list, + max_seq_length=args.max_seq_length) + + dev_ds = dev_ds.map(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if dev_ds.label_list else "float32") # label + ): fn(samples) + + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + num_classes = 1 if dev_ds.label_list == None else len(dev_ds.label_list) + model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_classes) + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + metric = metric_class() + best_acc = 0.0 + global_step = 0 + tic_train = time.time() + model.eval() + metric.reset() + for batch in dev_data_loader: + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + print("acc: %s\n, " % (res), end='') + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + metric_class = METRIC_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + train_ds = load_dataset('clue', args.task_name, splits='train') + + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=args.max_seq_length) + train_ds = train_ds.map(trans_func, lazy=True) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) + train_data_loader = DataLoader( + dataset=train_ds, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + dev_ds = 
load_dataset('clue', args.task_name, splits='dev') + + dev_ds = dev_ds.map(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + num_classes = 1 if train_ds.label_list == None else len(train_ds.label_list) + model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_classes) + if paddle.distributed.get_world_size() > 1: + model = paddle.DataParallel(model) + + if args.max_steps > 0: + num_training_steps = args.max_steps + num_train_epochs = math.ceil(num_training_steps / + len(train_data_loader)) + else: + num_training_steps = len(train_data_loader) * args.num_train_epochs + num_train_epochs = args.num_train_epochs + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. + decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm)) + + loss_fct = paddle.nn.loss.CrossEntropyLoss( + ) if train_ds.label_list else paddle.nn.loss.MSELoss() + + metric = metric_class() + best_acc = 0.0 + global_step = 0 + tic_train = time.time() + for epoch in range(num_train_epochs): + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids) + loss = loss_fct(logits, labels) + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + if global_step % args.logging_steps == 0: + print( + "global step %d/%d, epoch: %d, batch: %d, rank_id: %s, loss: %f, lr: %.10f, speed: %.4f step/s" + % (global_step, num_training_steps, epoch, step, + paddle.distributed.get_rank(), loss, optimizer.get_lr(), + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + if global_step % args.save_steps == 0 or global_step == num_training_steps: + tic_eval = time.time() + acc = evaluate(model, loss_fct, metric, dev_data_loader) + print("eval done total : %s s" % (time.time() - tic_eval)) + if acc > best_acc: + best_acc = acc + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # Need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + if global_step >= num_training_steps: + print("best_acc: ", best_acc) + return + print("best_acc: ", best_acc) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + if args.do_train: + do_train(args) + else: + do_eval(args) diff --git 
a/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh b/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh new file mode 100644 index 000000000000..ad74187f5a4b --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_clue.sh @@ -0,0 +1,25 @@ + +export TASK_NAME=$1 +export LR=$2 +export BS=$3 +export EPOCH=$4 +export MAX_SEQ_LEN=$5 +export CUDA_VISIBLE_DEVICES=$6 +export MODEL_PATH=$7 + +python -u ./run_clue.py \ + --model_type ernie \ + --model_name_or_path ${MODEL_PATH} \ + --task_name ${TASK_NAME} \ + --max_seq_length ${MAX_SEQ_LEN} \ + --batch_size ${BS} \ + --learning_rate ${LR} \ + --num_train_epochs ${EPOCH} \ + --logging_steps 100 \ + --seed 42 \ + --save_steps 100 \ + --warmup_proportion 0.1 \ + --weight_decay 0.01 \ + --adam_epsilon 1e-8 \ + --output_dir ${MODEL_PATH}/models/${TASK_NAME}/${LR}_${BS}/ \ + --device gpu \ diff --git a/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh b/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh new file mode 100644 index 000000000000..fbb5261d2f31 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/finetuning/run_one_search.sh @@ -0,0 +1,55 @@ +OUTPUT_DIR=$1 +TASK_NAME=$2 + +mkdir ${OUTPUT_DIR}/afqmc +mkdir ${OUTPUT_DIR}/tnews +mkdir ${OUTPUT_DIR}/ifly +mkdir ${OUTPUT_DIR}/ocnli +mkdir ${OUTPUT_DIR}/wsc +mkdir ${OUTPUT_DIR}/csl +mkdir ${OUTPUT_DIR}/cmnli + + +for lr in 1e-4 5e-5 3e-5 +do + for bs in 16 32 64 + do + echo bs: $bs, lr: $lr + + if [ $TASK_NAME == afqmc ] + then + sh run_clue.sh AFQMC $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/afqmc/${lr}_${bs}_3_128.log + fi + + if [ $TASK_NAME == tnews ] + then + sh run_clue.sh TNEWS $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/tnews/${lr}_${bs}_3_128.log + fi + + if [ $TASK_NAME == ifly ] + then + sh run_clue.sh IFLYTEK $lr $bs 6 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/ifly/${lr}_${bs}_6_128.log + fi + + if [ $TASK_NAME == ocnli ] + then + sh run_clue.sh OCNLI $lr $bs 6 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/ocnli/${lr}_${bs}_6_128.log + fi + + if [ $TASK_NAME == wsc ] + then + sh run_clue.sh CLUEWSC2020 $lr $bs 50 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/wsc/${lr}_${bs}_50_128.log + fi + + if [ $TASK_NAME == csl ] + then + sh run_clue.sh CSL $lr $bs 8 256 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/csl/${lr}_${bs}_8_256.log + fi + + if [ $TASK_NAME == cmnli ] + then + sh run_clue.sh CMNLI $lr $bs 3 128 $3 ${OUTPUT_DIR} > ${OUTPUT_DIR}/cmnli/${lr}_${bs}_3_128.log + fi + done +done + diff --git a/examples/model_compression/PP-MiniLM/general_distill/README.md b/examples/model_compression/PP-MiniLM/general_distill/README.md new file mode 100644 index 000000000000..df8767f5a50a --- /dev/null +++ b/examples/model_compression/PP-MiniLM/general_distill/README.md @@ -0,0 +1,64 @@ +# PP-MiniLM 任务无关蒸馏 + +## 环境要求 + +本实验基于 NVIDIA Tesla V100 32G 8 卡进行,训练周期约为 2-3 天。 + +## 原理介绍 + +任务无关知识蒸馏是用较大(层数更多、宽度更宽的)的基于 Transformer Layer 的预训练模型对较小(层数更少、宽度更窄的)的基于 Transformer Layer 的预训练模型进行蒸馏,从而得到更小、效果与较大模型更接近的预训练模型。 + +PP-MiniLM 参考了 MiniLMv2 提出的 Multi-Head Self-Attention Relation Distillation 蒸馏策略。MiniLMv2 算法是用 24 层 large-size 的教师模型倒数几层的 Q-Q、K-K、V-V 之间的relation对6层学生模型最后一层 Q-Q、K-K、V-V 之间的relation进行蒸馏。具体的做法是,首先将学生、教师用于蒸馏的层上的 Q、K、V 的 Head 数进行统一,然后计算各自 Q—Q、K-K、V-V 的点积,最后对教师和学生的点积计算KL散度损失。由于relation的shape是 `[batch_size, head_num, seq_len, seq_len]`,因此可以认为这里的relation是一种Token与Token之间的关系。 + +本方案在MiniLMv2策略的基础上,做了进一步优化: 通过引入多视角的注意力关系知识来进一步提升模型效果。MiniLMv2 的自注意力关系知识仅建模了 Token 与 Token 之间的关系,PP-MiniLM 在此基础上额外引入了样本与样本间的自注意力关系知识,也就是挖掘出更多教师模型所蕴含的知识,从而进一步优化模型效果。 + +具体来说,PP-MiniLM 利用了 
`roberta-wwm-ext-large` 第 20 层的 Q-Q、K-K、V-V 之间的 Sample 与 Sampl 之间关系对 6 层学生模型 PP-MiniLM 第 6 层的 Q-Q、K-K、V-V 之间的 Sample 与 Sample 之间的关系进行蒸馏。与MiniLMv2不同的是,PP-MiniLM的策略需要在统一Q、K、V的Head数之后,对Q、K、V转置为 `[seq_len, head_num, batch_size, head_dim]`,这样Q—Q、K- K、V-V 的点积则可以表达样本间的关系。经过我们的实验,这种方法比使用原始 MiniLMv2 算法在 CLUE 上平均准确率高 0.36。 + + +### 数据介绍 + +任务无关知识蒸馏的训练数据一般是预训练语料,可以使用公开的预训练语料 [CLUECorpus2020](https://github.com/CLUEbenchmark/CLUECorpus2020/)。需要将数据处理成一行一个句子的格式,再将数据文件分割成多个子文件(例如 64 个),放在同一个目录下。 + + +### 运行方式 + +```shell +sh run.sh # 包含general_distill.py的运行配置 +cd .. +``` + +其中 `general_distill.py` 参数释义如下: + +- `model_type` 指示了学生模型类型,当前仅支持 'ernie'、'roberta'。 +- `num_relation_heads` relation head 的个数,一般对于 large-size 的教师模型是64,对于 base-size 的教师模型是 48。 +- `teacher_model_type`指示了教师模型类型,当前仅支持 'ernie'、'roberta'。 +- `teacher_layer_index`蒸馏时使用的教师模型的层 +- `student_layer_index` 蒸馏时使用的学生模型的层 +- `teacher_model_name_or_path`教师模型的名称,例如`'roberta-wwm-ext-large'` +- `max_seq_length` 最大的样本长度 +- `num_layers` 学生模型的层数,目前仅支持 2,4,6 +- `logging_steps` 日志间隔 +- `max_steps` 最大迭代次数 +- `warmup_steps` 学习率增长得到`learning_rate`所需要的步数 +- `save_steps`保存模型的间隔步数 +- `weight_decay` 表示AdamW优化器中使用的 weight_decay 的系数 +- `output_dir`训练相关文件以及模型保存的输出路径 +- `device`设备选择,推荐使用 gpu +- `input_dir` 训练数据目录 +- `use_amp` 是否使用混合精度训练,默认 False +- `alpha`head间关系的权重,默认 0.0 +- `beta`样本间关系的权重,默认 0.0 + +将最终得到的模型绝对路径保存至 `$GENERAL_MODEL_DIR`,例如: + +```shell +GENERAL_MODEL_DIR=PaddleNLP/examples/model_compression/PP-MiniLM/general_distill/pretrain/model_400000 +``` + +## 模型精度 + +在 CLUE 数据集上经过超参寻优后,得到 CLUE 上各个任务上的最高准确率如下表: + +| AFQMC | TNEWS | IFLYTEK | CMNLI | OCNLI | WSC | CSL | CLUE 平均值 | +| ----- | ----- | ------- | ----- | ----- | ----- | ----- | ---------- | +| 74.28 | 57.33 | 61.72 | 81.06 | 76.20 | 86.51 | 78.77 | 73.70 | diff --git a/examples/model_compression/PP-MiniLM/general_distill/general_distill.py b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py new file mode 100644 index 000000000000..81f04f5e889f --- /dev/null +++ b/examples/model_compression/PP-MiniLM/general_distill/general_distill.py @@ -0,0 +1,491 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
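+# Task-agnostic distillation script: distills the Q-Q, K-K and V-V self-attention
+# relations of an intermediate teacher layer (e.g. layer 20 of
+# roberta-wwm-ext-large) into the last layer of a smaller ERNIE-architecture
+# student, following the MiniLMv2-style strategy described in the README.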
+ +import argparse +import os +import random +import time +from functools import partial +from concurrent.futures import ThreadPoolExecutor +import distutils.util +import math + +import numpy as np +import paddle +from paddle.io import DataLoader +import paddle.nn.functional as F +from paddle import tensor + +from paddlenlp.utils.log import logger +from paddlenlp.data import Tuple, Pad +from paddlenlp.utils.tools import TimeCostAverage +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.transformers import RobertaModel, RobertaTokenizer +from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification, ErnieTokenizer +from paddlenlp.transformers.distill_utils import to_distill, calc_multi_relation_loss + +MODEL_CLASSES = { + "roberta": (RobertaModel, RobertaTokenizer), + "ernie": (ErnieForSequenceClassification, ErnieTokenizer) +} + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default="ernie", + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--teacher_model_type", + default="ernie", + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--student_model_name_or_path", + default=None, + type=str, + required=False, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--teacher_model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model.") + parser.add_argument( + "--input_dir", + default=None, + type=str, + required=True, + help="The input directory where the data will be read from.", ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--learning_rate", + default=6e-4, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--num_layers", + default=6, + type=int, + help="Number layers of student model.", ) + parser.add_argument( + "--teacher_layer_index", + default=19, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--student_layer_index", + default=5, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--logging_steps", + type=int, + default=100, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=100, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--batch_size", + default=512, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--num_relation_heads", + default=64, + type=int, + help="The number of relation heads is 48 and 64 for base and large-size teacher model.", + ) + parser.add_argument("--beta", default=0.0, type=float, help="0.0 usually") + parser.add_argument("--alpha", default=0.0, type=float, help="0.0 usually") + parser.add_argument( + "--weight_decay", + default=0.01, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--warmup_steps", + default=-1, + type=int, + help="Linear warmup over warmup_steps. If > 0: Override warmup_proportion" + ) + parser.add_argument( + "--warmup_proportion", + default=0.01, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--adam_epsilon", + default=1e-8, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_steps", + default=400000, + type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.", + ) + parser.add_argument( + "--seed", default=42, type=int, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument( + "--use_amp", + type=distutils.util.strtobool, + default=False, + help="Enable mixed precision training.") + parser.add_argument( + "--scale_loss", + type=float, + default=2**15, + help="The value of scale_loss for fp16.") + args = parser.parse_args() + return args + + +def set_seed(args): + random.seed(args.seed + paddle.distributed.get_rank()) + np.random.seed(args.seed + paddle.distributed.get_rank()) + paddle.seed(args.seed + paddle.distributed.get_rank()) + + +class WorkerInitObj(object): + def __init__(self, seed): + self.seed = seed + + def __call__(self, id): + np.random.seed(seed=self.seed + id) + random.seed(self.seed + id) + + +def create_pretraining_dataset(input_file, shared_list, args, worker_init, + tokenizer): + train_data = PretrainingDataset( + input_file=input_file, + tokenizer=tokenizer, + max_seq_length=args.max_seq_length) + # files have been sharded, no need to dispatch again + train_batch_sampler = paddle.io.BatchSampler( + train_data, batch_size=args.batch_size, shuffle=True) + + # DataLoader cannot be pickled because of its place. 
+ # If it can be pickled, use global function instead of lambda and use + # ProcessPoolExecutor instead of ThreadPoolExecutor to prefetch. + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + ): fn(samples) + + train_data_loader = DataLoader( + dataset=train_data, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + worker_init_fn=worker_init, + return_list=True) + return train_data_loader, input_file + + +class PretrainingDataset(paddle.io.Dataset): + def __init__(self, input_file, tokenizer, max_seq_length): + self.input_file = input_file + f = open(input_file, 'r') + input_ids = [] + for i, line in enumerate(f): + line = line[:max_seq_length] + tokenized_example = tokenizer(line, max_seq_len=max_seq_length) + input_ids.append(tokenized_example['input_ids']) + + self.inputs = np.asarray(input_ids) + f.close() + + def __len__(self): + 'Denotes the total number of samples' + return len(self.inputs) + + def __getitem__(self, index): + input_ids = [np.asarray(self.inputs[index])] + return input_ids + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + worker_init = WorkerInitObj(args.seed + paddle.distributed.get_rank()) + args.model_type = args.model_type.lower() + + # For teacher + teacher_model_class, tokenizer_class = MODEL_CLASSES[ + args.teacher_model_type] + tokenizer = tokenizer_class.from_pretrained(args.teacher_model_name_or_path) + + # For student + model_class, _ = MODEL_CLASSES[args.model_type] + if args.num_layers == 6: + ernie = ErnieModel( + vocab_size=tokenizer.vocab_size, + num_hidden_layers=6, + hidden_act='relu', + intermediate_size=3072, + hidden_size=768) # layer: 6 + elif args.num_layers == 4: + ernie = ErnieModel( + vocab_size=tokenizer.vocab_size, + num_hidden_layers=4, + hidden_act='relu', + intermediate_size=1024, + hidden_size=256, + num_attention_heads=16) # layer: 4 + else: + ernie = ErnieModel( + vocab_size=tokenizer.vocab_size, + num_hidden_layers=2, + hidden_act='relu', + hidden_size=128, + intermediate_size=512) # layer: 2 + student = model_class(ernie) + + teacher = teacher_model_class.from_pretrained( + args.teacher_model_name_or_path) + pad_token_id = 0 + + if paddle.distributed.get_world_size() > 1: + student = paddle.DataParallel(student, find_unused_parameters=True) + teacher = paddle.DataParallel(teacher, find_unused_parameters=True) + + num_training_steps = args.max_steps + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. 
+ decay_params = [ + p.name for n, p in student.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=student.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params, + grad_clip=paddle.nn.ClipGradByGlobalNorm(args.max_grad_norm)) + + if args.use_amp: + scaler = paddle.amp.GradScaler(init_loss_scaling=args.scale_loss) + pool = ThreadPoolExecutor(1) + + teacher = to_distill( + teacher, return_qkv=True, layer_index=args.teacher_layer_index) + student = to_distill( + student, return_qkv=True, layer_index=args.student_layer_index) + + global_step = 0 + tic_train = time.time() + for epoch in range(args.num_train_epochs): + files = [ + os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) + if os.path.isfile(os.path.join(args.input_dir, f)) + ] + files.sort() + num_files = len(files) + random.Random(args.seed + epoch).shuffle(files) + f_start_id = 0 + + shared_file_list = {} + + if paddle.distributed.get_world_size() > num_files: + remainder = paddle.distributed.get_world_size() % num_files + + data_file = files[( + f_start_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank() + remainder * f_start_id) % + num_files] + else: + data_file = files[(f_start_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank()) % num_files] + + previous_file = data_file + + train_data_loader, _ = create_pretraining_dataset( + data_file, shared_file_list, args, worker_init, tokenizer) + + # TODO(guosheng): better way to process single file + single_file = True if f_start_id + 1 == len(files) else False + + for f_id in range(f_start_id, len(files)): + if not single_file and f_id == f_start_id: + continue + if paddle.distributed.get_world_size() > num_files: + data_file = files[( + f_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank() + remainder * f_id) % + num_files] + else: + data_file = files[(f_id * paddle.distributed.get_world_size() + + paddle.distributed.get_rank()) % num_files] + previous_file = data_file + dataset_future = pool.submit(create_pretraining_dataset, data_file, + shared_file_list, args, worker_init, + tokenizer) + + kl_loss_fct = paddle.nn.KLDivLoss('sum') + train_cost_avg = TimeCostAverage() + total_samples = 0 + batch_start = time.time() + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids = batch[0] + attention_mask = paddle.unsqueeze( + (input_ids == pad_token_id + ).astype(paddle.get_default_dtype()) * -1e4, + axis=[1, 2]) + with paddle.amp.auto_cast( + args.use_amp, + custom_white_list=["layer_norm", "gelu", "softmax"]): + student(input_ids) + with paddle.no_grad(): + teacher(input_ids) + # Q-Q relation + q_t, q_s = teacher.outputs.q, student.outputs.q + batch_size = q_t.shape[0] + pad_seq_len = q_t.shape[2] + loss_q = calc_multi_relation_loss( + kl_loss_fct, q_s, q_t, attention_mask, + args.num_relation_heads, args.alpha, args.beta) + del q_t, q_s + # K-K relation + k_t, k_s = teacher.outputs.k, student.outputs.k + loss_k = calc_multi_relation_loss( + kl_loss_fct, k_s, k_t, attention_mask, + args.num_relation_heads, args.alpha, args.beta) + del k_t, k_s + + # V-V relation + v_t, v_s = teacher.outputs.v, student.outputs.v + loss_v = calc_multi_relation_loss( + kl_loss_fct, v_s, v_t, attention_mask, + args.num_relation_heads, args.alpha, args.beta) + + del v_t, v_s + + loss = loss_q + 
loss_k + loss_v + loss /= args.num_relation_heads * pad_seq_len * batch_size + + if args.use_amp: + scaler.scale(loss).backward() + scaler.minimize(optimizer, loss) + else: + loss.backward() + + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + + total_samples += args.batch_size + train_run_cost = time.time() - batch_start + train_cost_avg.record(train_run_cost) + if global_step % args.logging_steps == 0: + logger.info( + "global step: %d, epoch: %d, batch: %d, loss: %f, " + "lr: %f, avg_batch_cost: %.5f sec, avg_samples: %.5f, ips: %.5f sequences/sec" + % (global_step, epoch, step, loss, optimizer.get_lr(), + train_cost_avg.get_average(), + total_samples / args.logging_steps, total_samples / + (args.logging_steps * train_cost_avg.get_average()))) + total_samples = 0 + train_cost_avg.reset() + if global_step % args.save_steps == 0 or global_step == num_training_steps: + if paddle.distributed.get_rank() == 0: + output_dir = os.path.join(args.output_dir, + "model_%d" % global_step) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = student._layers if isinstance( + student, paddle.DataParallel) else student + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + paddle.save( + optimizer.state_dict(), + os.path.join(output_dir, "model_state.pdopt")) + if global_step >= args.max_steps: + del train_data_loader + return + batch_start = time.time() + + del train_data_loader + train_data_loader, data_file = dataset_future.result(timeout=None) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + do_train(args) diff --git a/examples/model_compression/PP-MiniLM/general_distill/run.sh b/examples/model_compression/PP-MiniLM/general_distill/run.sh new file mode 100644 index 000000000000..3db0d135973b --- /dev/null +++ b/examples/model_compression/PP-MiniLM/general_distill/run.sh @@ -0,0 +1,70 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
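+# Launches task-agnostic distillation of PP-MiniLM on 8 GPUs with the
+# hyperparameters below; see general_distill.py for the meaning of each flag.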
+ + +set -eux + +unset CUDA_VISIBLE_DEVICES + +bs=128 +maxlen=128 +numH=64 +lr=6e-4 +maxStep=400000 +warmStep=4000 +wd=1e-2 + +teacher=roberta +teacherModel=roberta-wwm-ext-large + +alpha=0 +beta=1.0 +mode=hardest +use_amp=True +teacher_layer_index=19 +student_layer_index=5 +num_layers=6 + +hp_config=bs${bs}_maxlen${maxlen}_lr${lr}_wd${wd}_numH${numH}_maxStep${maxStep}_warmStep${warmStep}_adamW_maxnorm1p0_teacher_${teacherModel}_coldboot_teacher_vocab_index${teacher_layer_index}_4l-312d-batchbatch + +export PYTHONPATH=../../../../:$PYTHONPATH +output_dir="./pretrain_${hp_config}" + +mkdir -p ${output_dir} +cp ./general_distill.py ${output_dir}/ +cp ../../../../paddlenlp/transformers/distill_utils.py ${output_dir}/ + + +python3 -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" general_distill.py \ + --model_type ernie \ + --num_relation_heads ${numH} \ + --teacher_model_type ${teacher} \ + --teacher_layer_index ${teacher_layer_index} \ + --student_layer_index ${student_layer_index} \ + --teacher_model_name_or_path ${teacherModel} \ + --max_seq_length ${maxlen} \ + --num_layers ${num_layers} \ + --batch_size ${bs} \ + --learning_rate ${lr} \ + --logging_steps 20 \ + --max_steps ${maxStep} \ + --warmup_steps ${warmStep} \ + --save_steps 20000 \ + --weight_decay ${wd} \ + --output_dir ${output_dir} \ + --device gpu \ + --input_dir dataset/ \ + --use_amp ${use_amp} \ + --alpha ${alpha} \ + --beta ${beta} \ diff --git a/examples/model_compression/PP-MiniLM/inference/infer.py b/examples/model_compression/PP-MiniLM/inference/infer.py new file mode 100644 index 000000000000..40cf2af87ef8 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/inference/infer.py @@ -0,0 +1,271 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
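+# Runs Paddle Inference (optionally with TensorRT and INT8) on a CLUE dev set.
+# With --collect_shape the script only collects the tensor shape range info
+# needed for tuned dynamic shape; run it again without --collect_shape to
+# evaluate accuracy, or pass --perf to measure end-to-end latency instead.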
+ +import argparse +import os +import sys +from functools import partial +import numpy as np + +import paddle +from paddle import inference +from paddlenlp.datasets import load_dataset +from paddlenlp.data import Stack, Tuple, Pad + +sys.path.append("../") +from data import convert_example, METRIC_CLASSES, MODEL_CLASSES + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default='afqmc', + type=str, + help="The name of the task to perform predict, selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default='ernie', + type=str, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default='ppminilm-6l-768h', + type=str, + help="The directory or name of model.", ) + parser.add_argument( + "--model_path", + default='./quant_models/model', + type=str, + required=True, + help="The path prefix of inference model to be used.", ) + parser.add_argument( + "--device", + default="gpu", + choices=["gpu", "cpu", "xpu"], + help="Device selected for inference.", ) + parser.add_argument( + "--batch_size", + default=32, + type=int, + help="Batch size for predict.", ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--perf_warmup_steps", + default=20, + type=int, + help="Warmup steps for performance test.", ) + parser.add_argument( + "--use_trt", + action='store_true', + help="Whether to use inference engin TensorRT.", ) + parser.add_argument( + "--perf", + action='store_true', + help="Whether to test performance.", ) + parser.add_argument( + "--collect_shape", + action='store_true', + help="Whether collect shape range info.", ) + parser.add_argument( + "--int8", + action='store_true', + help="Whether to use int8 inference.", ) + args = parser.parse_args() + return args + + +@paddle.no_grad() +def evaluate(outputs, metric, data_loader): + metric.reset() + for i, batch in enumerate(data_loader): + input_ids, segment_ids, labels = batch + logits = paddle.to_tensor(outputs[i][0]) + correct = metric.compute(logits, labels) + metric.update(correct) + res = metric.accumulate() + print("acc: %s, " % res, end='') + + +class Predictor(object): + def __init__(self, predictor, input_handles, output_handles): + self.predictor = predictor + self.input_handles = input_handles + self.output_handles = output_handles + + @classmethod + def create_predictor(cls, args): + config = paddle.inference.Config(args.model_path + ".pdmodel", + args.model_path + ".pdiparams") + if args.device == "gpu": + # set GPU configs accordingly + config.enable_use_gpu(100, 0) + cls.device = paddle.set_device("gpu") + elif args.device == "cpu": + # set CPU configs accordingly, + # such as enable_mkldnn, set_cpu_math_library_num_threads + config.disable_gpu() + cls.device = paddle.set_device("cpu") + elif args.device == "xpu": + # set XPU configs accordingly + config.enable_xpu(100) + if args.use_trt: + if args.int8: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=inference.PrecisionType.Int8, + max_batch_size=args.batch_size, + min_subgraph_size=5, + use_static=False, + use_calib_mode=False) + else: + config.enable_tensorrt_engine( + workspace_size=1 << 30, + precision_mode=inference.PrecisionType.Float32, + 
max_batch_size=args.batch_size, + min_subgraph_size=5, + use_static=False, + use_calib_mode=False) + print("Enable TensorRT is: {}".format( + config.tensorrt_engine_enabled())) + if args.collect_shape: + config.collect_shape_range_info( + os.path.join( + os.path.dirname(args.model_path), args.task_name + + '_shape_range_info.pbtxt')) + else: + config.enable_tuned_tensorrt_dynamic_shape( + os.path.join( + os.path.dirname(args.model_path), + args.task_name + "_shape_range_info.pbtxt"), True) + predictor = paddle.inference.create_predictor(config) + input_handles = [ + predictor.get_input_handle(name) + for name in predictor.get_input_names() + ] + output_handles = [ + predictor.get_output_handle(name) + for name in predictor.get_output_names() + ] + + return cls(predictor, input_handles, output_handles) + + def predict_batch(self, data): + for input_field, input_handle in zip(data, self.input_handles): + input_handle.copy_from_cpu(input_field.numpy() if isinstance( + input_field, paddle.Tensor) else input_field) + self.predictor.run() + output = [ + output_handle.copy_to_cpu() for output_handle in self.output_handles + ] + + return output + + def predict(self, dataset, collate_fn, args, batch_size=1): + metric = METRIC_CLASSES[args.task_name]() + batch_sampler = paddle.io.BatchSampler( + dataset, batch_size=batch_size, shuffle=False) + data_loader = paddle.io.DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=0, + return_list=True) + outputs = [] + metric.reset() + for i, data in enumerate(data_loader): + if len(data) == 2: + output = self.predict_batch(data) + else: + output = self.predict_batch([data[0], data[1]]) + logits = paddle.to_tensor(output) + correct = metric.compute(logits, data[2]) + metric.update(correct) + outputs.append(output) + if len(data) > 2: + res = metric.accumulate() + print("task name: %s, acc: %s, " % (args.task_name, res), end='') + + return outputs + + def predict_perf(self, dataset, collate_fn, args, batch_size=1): + batch_sampler = paddle.io.BatchSampler( + dataset, batch_size=batch_size, shuffle=False) + data_loader = paddle.io.DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=0, + return_list=True) + time1 = time.time() + for i, data in enumerate(data_loader): + if i < args.perf_warmup_steps: # skip warmup steps. 
+ continue + output = self.predict_batch([data[0], data[1]]) + logits = paddle.to_tensor(output) + + print("time: ", time.time() - time1) + + +def main(): + paddle.seed(42) + args = parse_args() + + args.task_name = args.task_name.lower() + args.model_type = args.model_type.lower() + + predictor = Predictor.create_predictor(args) + + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + + dev_ds = load_dataset('clue', args.task_name, splits='dev') + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=dev_ds.label_list, + max_seq_length=args.max_seq_length, + is_test=False) + + dev_ds = dev_ds.map(trans_func, lazy=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if dev_ds.label_list else "float32") # label + ): fn(samples) + if args.perf: + outputs = predictor.predict_perf( + dev_ds, + batch_size=args.batch_size, + collate_fn=batchify_fn, + args=args) + else: + outputs = predictor.predict( + dev_ds, + batch_size=args.batch_size, + collate_fn=batchify_fn, + args=args) + + +if __name__ == "__main__": + main() diff --git a/examples/model_compression/PP-MiniLM/inference/infer_all.sh b/examples/model_compression/PP-MiniLM/inference/infer_all.sh new file mode 100644 index 000000000000..f26680415c56 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/inference/infer_all.sh @@ -0,0 +1,27 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +for task in afqmc tnews iflytek cmnli ocnli cluewsc2020 csl +do + for bs in 4 8 + do + for algo in abs_max avg hist mse + do + python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt --collect_shape + python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/${algo}${bs}/int8 --int8 --use_trt + echo this is ${task}, ${algo}, ${bs} + done + done +done diff --git a/examples/model_compression/PP-MiniLM/inference/infer_perf.sh b/examples/model_compression/PP-MiniLM/inference/infer_perf.sh new file mode 100644 index 000000000000..dc37f7006584 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/inference/infer_perf.sh @@ -0,0 +1,41 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
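+
+# Latency comparison on TNEWS: benchmarks the original FP32 model, the pruned
+# (width 0.75) FP32 sub-model, and the INT8 quantized model by running infer.py
+# with --perf several times each. The first invocation per model adds
+# --collect_shape to gather TensorRT dynamic-shape info; the repeated runs that
+# follow provide the timings to report.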
+
+task=tnews
+echo Inference of original FP32 model
+python infer.py --task_name ${task} --model_path tnews/float --use_trt --collect_shape --perf
+python infer.py --task_name ${task} --model_path tnews/float --use_trt --perf
+python infer.py --task_name ${task} --model_path tnews/float --use_trt --perf
+python infer.py --task_name ${task} --model_path tnews/float --use_trt --perf
+python infer.py --task_name ${task} --model_path tnews/float --use_trt --perf
+python infer.py --task_name ${task} --model_path tnews/float --use_trt --perf
+
+
+echo After pruning
+python infer.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --collect_shape --perf
+python infer.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --perf
+python infer.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --perf
+python infer.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --perf
+python infer.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --perf
+python infer.py --task_name ${task} --model_path ofa_models/TNEWS/0.75/sub_static/float --use_trt --perf
+
+echo After quantization
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --collect_shape --perf
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --perf
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --perf
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --perf
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --perf
+python infer.py --task_name ${task} --model_path ../quantization/${task}_quant_models/mse4/int8 --int8 --use_trt --perf
+
+
diff --git a/examples/model_compression/PP-MiniLM/pp-minilm.png b/examples/model_compression/PP-MiniLM/pp-minilm.png
new file mode 100644
index 000000000000..8fc843169788
Binary files /dev/null and b/examples/model_compression/PP-MiniLM/pp-minilm.png differ
diff --git a/examples/model_compression/PP-MiniLM/pruning/export.sh b/examples/model_compression/PP-MiniLM/pruning/export.sh
new file mode 100644
index 000000000000..bec1cc491ac2
--- /dev/null
+++ b/examples/model_compression/PP-MiniLM/pruning/export.sh
@@ -0,0 +1,21 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
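+
+# Export the width-0.75 sub-model of a single pruned task model to dygraph and
+# static-graph formats. Example (directory layout follows prune.py's defaults):
+#   sh export.sh pruned_models TNEWS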
+ +MODEL_PATH=$1 +TASK_NAME=$2 +python export_model.py --model_type ernie \ + --model_name_or_path ${MODEL_PATH}/${TASK_NAME}/0.75/best_model \ + --sub_model_output_dir ${MODEL_PATH}/${TASK_NAME}/0.75/sub/ \ + --static_sub_model ${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float \ + --n_gpu 1 --width_mult 0.75 diff --git a/examples/model_compression/PP-MiniLM/pruning/export_all.sh b/examples/model_compression/PP-MiniLM/pruning/export_all.sh new file mode 100644 index 000000000000..fe730eed0094 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/pruning/export_all.sh @@ -0,0 +1,26 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MODEL_PATH=pruned_models + +for TASK_NAME in AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL + +do + python export_model.py --model_type ernie \ + --model_name_or_path ${MODEL_PATH}/${TASK_NAME}/0.75/best_model \ + --sub_model_output_dir ${MODEL_PATH}/${TASK_NAME}/0.75/sub/ \ + --static_sub_model ${MODEL_PATH}/${TASK_NAME}/0.75/sub_static/float \ + --n_gpu 1 --width_mult 0.75 + +done diff --git a/examples/model_compression/PP-MiniLM/pruning/export_model.py b/examples/model_compression/PP-MiniLM/pruning/export_model.py new file mode 100644 index 000000000000..ac30585852f6 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/pruning/export_model.py @@ -0,0 +1,216 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
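+
+# Rebuilds the OFA supernet from a checkpoint produced by prune.py, selects the
+# sub-network at --width_mult (and optionally --depth_mult), and saves it both
+# as a dygraph model (save_pretrained) and, when --static_sub_model is given,
+# as a static-graph inference model via paddle.jit.to_static / paddle.jit.save.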
+ +import argparse +import logging +import os +import math +import random +import time +import json +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddlenlp.transformers import ErnieModel, ErnieForSequenceClassification, ErnieTokenizer +from paddlenlp.utils.log import logger +from paddleslim.nas.ofa import OFA, utils +from paddleslim.nas.ofa.convert_super import Convert, supernet +from paddleslim.nas.ofa.layers import BaseBlock + +MODEL_CLASSES = {"ernie": (ErnieForSequenceClassification, ErnieTokenizer), } + + +def ernie_forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None): + wtype = self.pooler.dense.fn.weight.dtype if hasattr( + self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype + if attention_mask is None: + attention_mask = paddle.unsqueeze( + (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) + embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) + encoded_layer = self.encoder(embedding_output, attention_mask) + pooled_output = self.pooler(encoded_layer) + + return encoded_layer, pooled_output + + +ErnieModel.forward = ernie_forward + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--sub_model_output_dir", + default=None, + type=str, + required=True, + help="The output directory where the sub model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--static_sub_model", + default=None, + type=str, + help="The output directory where the sub static model will be written. If set to None, not export static model", + ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. 
Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--n_gpu", + type=int, + default=1, + help="number of gpus to use, 0 for cpu.") + parser.add_argument( + '--width_mult', + type=float, + default=1.0, + help="width mult you want to export") + parser.add_argument( + '--depth_mult', + type=float, + default=1.0, + help="depth mult you want to export") + args = parser.parse_args() + return args + + +def export_static_model(model, model_path, max_seq_length): + input_shape = [ + paddle.static.InputSpec( + shape=[None, max_seq_length], dtype='int64'), + paddle.static.InputSpec( + shape=[None, max_seq_length], dtype='int64') + ] + net = paddle.jit.to_static(model, input_spec=input_shape) + paddle.jit.save(net, model_path) + + +def do_train(args): + paddle.set_device("gpu" if args.n_gpu else "cpu") + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + config_path = os.path.join(args.model_name_or_path, 'model_config.json') + cfg_dict = dict(json.loads(open(config_path).read())) + + kept_layers_index = {} + if args.depth_mult < 1.0: + depth = round(cfg_dict["init_args"][0]['num_hidden_layers'] * + args.depth_mult) + cfg_dict["init_args"][0]['num_hidden_layers'] = depth + for idx, i in enumerate(range(1, depth + 1)): + kept_layers_index[idx] = math.floor(i / args.depth_mult) - 1 + + os.rename(config_path, config_path + '_bak') + with open(config_path, "w", encoding="utf-8") as f: + f.write(json.dumps(cfg_dict, ensure_ascii=False)) + + num_labels = cfg_dict['num_classes'] + + model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_labels) + + origin_model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_labels) + + os.rename(config_path + '_bak', config_path) + + sp_config = supernet(expand_ratio=[1.0, args.width_mult]) + model = Convert(sp_config).convert(model) + + ofa_model = OFA(model) + + sd = paddle.load( + os.path.join(args.model_name_or_path, 'model_state.pdparams')) + + if len(kept_layers_index) == 0: + ofa_model.model.set_state_dict(sd) + else: + for name, params in ofa_model.model.named_parameters(): + if 'encoder' not in name: + params.set_value(sd[name]) + else: + idx = int(name.strip().split('.')[3]) + mapping_name = name.replace( + '.' + str(idx) + '.', + '.' 
+ str(kept_layers_index[idx]) + '.') + params.set_value(sd[mapping_name]) + + best_config = utils.dynabert_config(ofa_model, args.width_mult) + for name, sublayer in ofa_model.model.named_sublayers(): + if isinstance(sublayer, paddle.nn.MultiHeadAttention): + sublayer.num_heads = int(args.width_mult * sublayer.num_heads) + + origin_model_new = ofa_model.export( + best_config, + input_shapes=[[1, args.max_seq_length], [1, args.max_seq_length]], + input_dtypes=['int64', 'int64'], + origin_model=origin_model) + for name, sublayer in origin_model_new.named_sublayers(): + if isinstance(sublayer, paddle.nn.MultiHeadAttention): + sublayer.num_heads = int(args.width_mult * sublayer.num_heads) + + output_dir = os.path.join(args.sub_model_output_dir, + "model_width_%.5f" % args.width_mult) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = origin_model_new + model_to_save.save_pretrained(output_dir) + + if args.static_sub_model != None: + export_static_model(origin_model_new, args.static_sub_model, + args.max_seq_length) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + do_train(args) diff --git a/examples/model_compression/PP-MiniLM/pruning/prune.py b/examples/model_compression/PP-MiniLM/pruning/prune.py new file mode 100644 index 000000000000..6442254a7356 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/pruning/prune.py @@ -0,0 +1,445 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
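+
+# DynaBERT-style width-adaptive pruning with PaddleSlim OFA: the fine-tuned
+# model is converted into a supernet, attention heads and FFN neurons are
+# reordered by importance, and sub-networks at each value in --width_mult_list
+# are trained with distillation from the unpruned fine-tuned model
+# (intermediate representation loss plus a soft cross-entropy loss on logits).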
+ +import argparse +import logging +import os +import sys +import random +import time +import math +from functools import partial + +import numpy as np +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.io import DataLoader + +from paddlenlp.data import Stack, Tuple, Pad, Dict +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import LinearDecayWithWarmup +from paddlenlp.utils.log import logger +from paddlenlp.transformers import ErnieForSequenceClassification, ErnieTokenizer, ErnieModel + +from paddleslim.nas.ofa import OFA, DistillConfig, utils +from paddleslim.nas.ofa.utils import nlp_utils +from paddleslim.nas.ofa.convert_super import Convert, supernet + +sys.path.append("../") +from data import convert_example, METRIC_CLASSES, MODEL_CLASSES + + +def parse_args(): + parser = argparse.ArgumentParser() + + # Required parameters + parser.add_argument( + "--task_name", + default=None, + type=str, + required=True, + help="The name of the task to train selected in the list: " + + ", ".join(METRIC_CLASSES.keys()), ) + parser.add_argument( + "--model_type", + default=None, + type=str, + required=True, + help="Model type selected in the list: " + + ", ".join(MODEL_CLASSES.keys()), ) + parser.add_argument( + "--model_name_or_path", + default=None, + type=str, + required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + + ", ".join( + sum([ + list(classes[-1].pretrained_init_configuration.keys()) + for classes in MODEL_CLASSES.values() + ], [])), ) + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--glue_dir", + default="/root/.paddlenlp/datasets/Clue/", + type=str, + required=False, + help="The Glue directory.", ) + parser.add_argument( + "--max_seq_length", + default=128, + type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.", ) + parser.add_argument( + "--batch_size", + default=8, + type=int, + help="Batch size per GPU/CPU for training.", ) + parser.add_argument( + "--learning_rate", + default=5e-5, + type=float, + help="The initial learning rate for Adam.") + parser.add_argument( + "--weight_decay", + default=0.0, + type=float, + help="Weight decay if we apply some.") + parser.add_argument( + "--adam_epsilon", + default=1e-8, + type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--lambda_logit", + default=1.0, + type=float, + help="lambda for logit loss.") + parser.add_argument( + "--num_train_epochs", + default=3, + type=int, + help="Total number of training epochs to perform.", ) + parser.add_argument( + "--max_steps", + default=-1, + type=int, + help="If > 0: set total number of training steps to perform. 
Override num_train_epochs.", + ) + parser.add_argument( + "--warmup_steps", + default=0, + type=int, + help="Linear warmup over warmup_steps.") + parser.add_argument( + "--warmup_proportion", + default=0.1, + type=float, + help="Linear warmup proportion over total steps.") + parser.add_argument( + "--logging_steps", + type=int, + default=500, + help="Log every X updates steps.") + parser.add_argument( + "--save_steps", + type=int, + default=500, + help="Save checkpoint every X updates steps.") + parser.add_argument( + "--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument( + "--device", + default="gpu", + type=str, + choices=["gpu", "cpu", "xpu"], + help="The device to select to train the model, is must be cpu/gpu/xpu.") + parser.add_argument( + '--width_mult_list', + nargs='+', + type=float, + default=[1.0, 5 / 6, 2 / 3, 0.5], + help="width mult in compress") + args = parser.parse_args() + return args + + +def set_seed(args): + # Use the same data seed(for data shuffle) for all procs to guarantee data + # consistency after sharding. + random.seed(args.seed) + np.random.seed(args.seed) + # Maybe different op seeds(for dropout) for different procs is better. By: + # `paddle.seed(args.seed + paddle.distributed.get_rank())` + paddle.seed(args.seed) + + +@paddle.no_grad() +def evaluate(model, metric, data_loader, width_mult, student=False): + model.eval() + metric.reset() + for i, batch in enumerate(data_loader): + input_ids, segment_ids, labels = batch + logits = model(input_ids, segment_ids, attention_mask=[None, None]) + if isinstance(logits, tuple): + logits = logits[0] + correct = metric.compute(logits, labels) + metric.update(correct) + + res = metric.accumulate() + print("width_mult: %s, acc: %s, " % (str(width_mult), res), end='') + model.train() + return res + + +### monkey patch for bert forward to accept [attention_mask, head_mask] as attention_mask +def ernie_forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=[None, None]): + wtype = self.pooler.dense.fn.weight.dtype if hasattr( + self.pooler.dense, 'fn') else self.pooler.dense.weight.dtype + if attention_mask[0] is None: + attention_mask[0] = paddle.unsqueeze( + (input_ids == self.pad_token_id).astype(wtype) * -1e9, axis=[1, 2]) + embedding_output = self.embeddings(input_ids, token_type_ids, position_ids) + encoded_layer = self.encoder(embedding_output, attention_mask) + pooled_output = self.pooler(encoded_layer) + + return encoded_layer, pooled_output + + +ErnieModel.forward = ernie_forward + + +### reorder weights according head importance and neuron importance +def reorder_neuron_head(model, head_importance, neuron_importance): + # reorder heads and ffn neurons + for layer, current_importance in enumerate(neuron_importance): + # reorder heads + idx = paddle.argsort(head_importance[layer], descending=True) + nlp_utils.reorder_head(model.ernie.encoder.layers[layer].self_attn, idx) + # reorder neurons + idx = paddle.argsort( + paddle.to_tensor(current_importance), descending=True) + nlp_utils.reorder_neuron( + model.ernie.encoder.layers[layer].linear1.fn, idx, dim=1) + nlp_utils.reorder_neuron( + model.ernie.encoder.layers[layer].linear2.fn, idx, dim=0) + + +def soft_cross_entropy(inp, target): + inp_likelihood = F.log_softmax(inp, axis=-1) + target_prob = F.softmax(target, axis=-1) + return -1. 
* paddle.mean(paddle.sum(inp_likelihood * target_prob, axis=-1)) + + +def do_train(args): + paddle.set_device(args.device) + if paddle.distributed.get_world_size() > 1: + paddle.distributed.init_parallel_env() + + set_seed(args) + + args.task_name = args.task_name.lower() + metric_class = METRIC_CLASSES[args.task_name] + args.model_type = args.model_type.lower() + model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + train_ds = load_dataset('clue', args.task_name, splits='train') + tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=args.max_seq_length) + train_ds = train_ds.map(trans_func, lazy=True) + train_batch_sampler = paddle.io.DistributedBatchSampler( + train_ds, batch_size=args.batch_size, shuffle=True) + batchify_fn = lambda samples, fn=Tuple( + Pad(axis=0, pad_val=tokenizer.pad_token_id), # input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # segment + Stack(dtype="int64" if train_ds.label_list else "float32") # label + ): fn(samples) + + train_data_loader = DataLoader( + dataset=train_ds, + batch_sampler=train_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + + dev_ds = load_dataset('clue', args.task_name, splits='dev') + dev_ds = dev_ds.map(trans_func, lazy=True) + dev_batch_sampler = paddle.io.BatchSampler( + dev_ds, batch_size=args.batch_size, shuffle=False) + dev_data_loader = DataLoader( + dataset=dev_ds, + batch_sampler=dev_batch_sampler, + collate_fn=batchify_fn, + num_workers=0, + return_list=True) + num_labels = 1 if train_ds.label_list == None else len(train_ds.label_list) + + model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_labels) + + # Step1: Initialize a dictionary to save the weights from the origin BERT model. + origin_weights = model.state_dict() + + # Step2: Convert origin model to supernet. + sp_config = supernet(expand_ratio=[1.0]) + model = Convert(sp_config).convert(model) + # Use weights saved in the dictionary to initialize supernet. + utils.set_state_dict(model, origin_weights) + del origin_weights + + super_sd = paddle.load( + os.path.join(args.model_name_or_path, 'model_state.pdparams')) + model.set_state_dict(super_sd) + + # Step3: Define teacher model. + teacher_model = model_class.from_pretrained( + args.model_name_or_path, num_classes=num_labels) + + # Step4: Config about distillation. + mapping_layers = ['ernie.embeddings'] + for idx in range(model.ernie.config['num_hidden_layers']): + mapping_layers.append('ernie.encoder.layers.{}'.format(idx)) + + default_distill_config = { + 'lambda_distill': 0.1, + 'teacher_model': teacher_model, + 'mapping_layers': mapping_layers, + } + distill_config = DistillConfig(**default_distill_config) + + # Step5: Config in supernet training. + ofa_model = OFA(model, + distill_config=distill_config, + elastic_order=['width']) + + criterion = paddle.nn.loss.CrossEntropyLoss( + ) if train_ds.label_list else paddle.nn.loss.MSELoss() + + metric = metric_class() + + #### Step6: Calculate the importance of neurons and head, + #### and then reorder them according to the importance. 
+ head_importance, neuron_importance = nlp_utils.compute_neuron_head_importance( + args.task_name, + ofa_model.model, + dev_data_loader, + loss_fct=criterion, + num_layers=model.ernie.config['num_hidden_layers'], + num_heads=model.ernie.config['num_attention_heads']) + reorder_neuron_head(ofa_model.model, head_importance, neuron_importance) + + if paddle.distributed.get_world_size() > 1: + ofa_model.model = paddle.DataParallel(ofa_model.model) + + if args.max_steps > 0: + num_training_steps = args.max_steps + num_train_epochs = math.ceil(num_training_steps / + len(train_data_loader)) + else: + num_training_steps = len(train_data_loader) * args.num_train_epochs + num_train_epochs = args.num_train_epochs + + warmup = args.warmup_steps if args.warmup_steps > 0 else args.warmup_proportion + + lr_scheduler = LinearDecayWithWarmup(args.learning_rate, num_training_steps, + warmup) + + # Generate parameter names needed to perform weight decay. + # All bias and LayerNorm parameters are excluded. + decay_params = [ + p.name for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + optimizer = paddle.optimizer.AdamW( + learning_rate=lr_scheduler, + beta1=0.9, + beta2=0.999, + epsilon=args.adam_epsilon, + parameters=model.parameters(), + weight_decay=args.weight_decay, + apply_decay_param_fun=lambda x: x in decay_params, + grad_clip=nn.ClipGradByGlobalNorm(args.max_grad_norm)) + + global_step = 0 + tic_train = time.time() + best_res = 0.0 + for epoch in range(num_train_epochs): + # Step7: Set current epoch and task. + ofa_model.set_epoch(epoch) + ofa_model.set_task('width') + + for step, batch in enumerate(train_data_loader): + global_step += 1 + input_ids, segment_ids, labels = batch + + for width_mult in args.width_mult_list: + # Step8: Broadcast supernet config from width_mult, + # and use this config in supernet training. 
+ net_config = utils.dynabert_config(ofa_model, width_mult) + ofa_model.set_net_config(net_config) + logits, teacher_logits = ofa_model( + input_ids, segment_ids, attention_mask=[None, None]) + rep_loss = ofa_model.calc_distill_loss() + logit_loss = soft_cross_entropy(logits, teacher_logits.detach()) + loss = rep_loss + args.lambda_logit * logit_loss + loss.backward() + optimizer.step() + lr_scheduler.step() + optimizer.clear_grad() + + if global_step % args.logging_steps == 0: + logger.info( + "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s" + % (global_step, epoch, step, loss, + args.logging_steps / (time.time() - tic_train))) + tic_train = time.time() + + if global_step % args.save_steps == 0 or global_step == num_training_steps: + tic_eval = time.time() + evaluate(teacher_model, metric, dev_data_loader, width_mult=100) + print("eval done total : %s s" % (time.time() - tic_eval)) + for idx, width_mult in enumerate(args.width_mult_list): + net_config = utils.dynabert_config(ofa_model, width_mult) + ofa_model.set_net_config(net_config) + tic_eval = time.time() + res = evaluate(ofa_model, metric, dev_data_loader, + width_mult) + print("eval done total : %s s" % (time.time() - tic_eval)) + + if best_res < res: + output_dir = args.output_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + # need better way to get inner model of DataParallel + model_to_save = model._layers if isinstance( + model, paddle.DataParallel) else model + model_to_save.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + best_res = res + if global_step >= num_training_steps: + return + print("best_res: ", best_res) + + +def print_arguments(args): + """print arguments""" + print('----------- Configuration Arguments -----------') + for arg, value in sorted(vars(args).items()): + print('%s: %s' % (arg, value)) + print('------------------------------------------------') + + +if __name__ == "__main__": + args = parse_args() + print_arguments(args) + do_train(args) diff --git a/examples/model_compression/PP-MiniLM/pruning/prune.sh b/examples/model_compression/PP-MiniLM/pruning/prune.sh new file mode 100644 index 000000000000..51e196909348 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/pruning/prune.sh @@ -0,0 +1,35 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
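+
+# Positional arguments: task name, learning rate, batch size, epochs, max
+# sequence length, visible GPU id, fine-tuned student model directory, and the
+# width_mult list for OFA training. Example (student directory is illustrative):
+#   sh prune.sh CLUEWSC2020 1e-4 32 50 128 0 <fine-tuned-model-dir> "0.75"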
+ +export TASK_NAME=$1 +export LR=$2 +export BATCH_SIZE=$3 +export PRE_EPOCHS=$4 +export SEQ_LEN=$5 +export CUDA_VISIBLE_DEVICES=$6 +export STUDENT_DIR=$7 +export WIDTH_LIST=$8 + +python -u ./prune.py --model_type ernie \ + --model_name_or_path ${STUDENT_DIR} \ + --task_name $TASK_NAME --max_seq_length ${SEQ_LEN} \ + --batch_size ${BATCH_SIZE} \ + --learning_rate ${LR} \ + --num_train_epochs ${PRE_EPOCHS} \ + --logging_steps 100 \ + --save_steps 100 \ + --output_dir ./pruned_models/$TASK_NAME/0.75/best_model/ \ + --device gpu \ + --width_mult_list ${WIDTH_LIST} + diff --git a/examples/model_compression/PP-MiniLM/quantization/quant_all.sh b/examples/model_compression/PP-MiniLM/quantization/quant_all.sh new file mode 100644 index 000000000000..1b39c8ca0e5d --- /dev/null +++ b/examples/model_compression/PP-MiniLM/quantization/quant_all.sh @@ -0,0 +1,20 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +MODEL_DIR=../pruning/pruned_models/ + +for task in AFQMC TNEWS IFLYTEK CMNLI OCNLI CLUEWSC2020 CSL +do + python quant_post.py --task_name ${task} --input_dir ${MODEL_DIR}/${task}/0.75/sub_static +done diff --git a/examples/model_compression/PP-MiniLM/quantization/quant_post.py b/examples/model_compression/PP-MiniLM/quantization/quant_post.py new file mode 100644 index 000000000000..df926f8b5fc8 --- /dev/null +++ b/examples/model_compression/PP-MiniLM/quantization/quant_post.py @@ -0,0 +1,130 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
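+
+# Post-training (offline) quantization with PaddleSlim's quant_post_static:
+# the exported static-graph model is calibrated on the CLUE dev split and INT8
+# weights are written for matmul/matmul_v2 ops. For each task it sweeps
+# calibration batch sizes {4, 8} and algorithms {abs_max, avg, mse, hist},
+# saving each variant to <task>_quant_models/<algo><batch_size>/.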
+ +import six +import sys +import os +import time +import argparse +from functools import partial + +import numpy as np +import paddle + +import paddleslim +from paddlenlp.data import Stack, Tuple, Pad, Dict +from paddlenlp.datasets import load_dataset +from paddlenlp.transformers import ErnieTokenizer + +sys.path.append("../") +from data import convert_example, METRIC_CLASSES, MODEL_CLASSES + +parser = argparse.ArgumentParser() + +parser.add_argument( + "--task_name", type=str, default="afqmc", required=False, help="task_name") +parser.add_argument( + "--input_dir", + type=str, + default="afqmc", + required=False, + help="Input task model directory.") + +parser.add_argument( + "--save_model_filename", + type=str, + default="int8.pdmodel", + required=False, + help="File name of quantified model.") + +parser.add_argument( + "--save_params_filename", + type=str, + default="int8.pdiparams", + required=False, + help="File name of quantified model's parameters.") + +parser.add_argument( + "--input_model_filename", + type=str, + default="float.pdmodel", + required=False, + help="File name of float model.") + +parser.add_argument( + "--input_param_filename", + type=str, + default="float.pdiparams", + required=False, + help="File name of float model's parameters.") + +parser.add_argument( + "--model_name_or_path", + default='ppminilm-6l-768h', + type=str, + help="Model name or the directory of model directory.", ) + +args = parser.parse_args() + + +def quant_post(args, batch_size=8, algo='avg'): + place = paddle.set_device("gpu") + exe = paddle.static.Executor(place) + args.task_name = args.task_name.lower() + + train_ds = load_dataset("clue", args.task_name, splits="dev") + + tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path) + + trans_func = partial( + convert_example, + tokenizer=tokenizer, + label_list=train_ds.label_list, + max_seq_length=128, + is_test=True) + train_ds = train_ds.map(trans_func, lazy=True) + + def batch_generator_func(): + batch_data = [[], []] + for data in train_ds: + batch_data[0].append(data[0]) + batch_data[1].append(data[1]) + if len(batch_data[0]) == batch_size: + input_ids = Pad(axis=0, pad_val=0)(batch_data[0]) + segment_ids = Pad(axis=0, pad_val=0)(batch_data[1]) + yield [input_ids, segment_ids] + batch_data = [[], []] + + paddleslim.quant.quant_post_static( + exe, + args.input_dir, + os.path.join(args.task_name + '_quant_models', algo + str(batch_size)), + save_model_filename=args.save_model_filename, + save_params_filename=args.save_params_filename, + algo=algo, + hist_percent=0.9999, + batch_generator=batch_generator_func, + model_filename=args.input_model_filename, + params_filename=args.input_param_filename, + quantizable_op_type=['matmul', 'matmul_v2'], + weight_bits=8, + weight_quantize_type='channel_wise_abs_max', + batch_nums=1, ) + + +if __name__ == '__main__': + paddle.enable_static() + for batch_size in [4, 8]: + for algo in ['abs_max', 'avg', 'mse', 'hist']: + quant_post(args, batch_size, algo) diff --git a/paddlenlp/transformers/distill_utils.py b/paddlenlp/transformers/distill_utils.py index 037cd8ceb849..3f67c0d022b1 100644 --- a/paddlenlp/transformers/distill_utils.py +++ b/paddlenlp/transformers/distill_utils.py @@ -140,7 +140,6 @@ def calc_multi_relation_loss(loss_fct, def calc_minilm_loss(loss_fct, s, t, attn_mask, num_relation_heads=0): """ Calculates loss for Q-Q, K-K, V-V relation from MiniLMv2. - Args: loss_fct (callable): Loss function for distillation. It only supports kl_div loss now. 
@@ -197,7 +196,6 @@ def to_distill(self, expose attributes `outputs.q`, `outputs.k`, `outputs.v`, `outputs.scaled_qks`, `outputs.hidden_states`and `outputs.attentions` of the object for distillation. - It could be returned intermediate tensor using in MiniLM and TinyBERT strategy. """ @@ -435,4 +433,4 @@ def bert_forward(self, input_ids, token_type_ids=None, attention_mask=None): sequence_output, pooled_output = model(input_ids, token_type_ids, attention_mask) - return encoder.attentions, encoder.hidden_states + return encoder.attentions, encoder.hidden_states \ No newline at end of file diff --git a/paddlenlp/transformers/ernie/modeling.py b/paddlenlp/transformers/ernie/modeling.py index f0735f1c44ff..78c040a88d81 100644 --- a/paddlenlp/transformers/ernie/modeling.py +++ b/paddlenlp/transformers/ernie/modeling.py @@ -168,6 +168,20 @@ class ErniePretrainedModel(PretrainedModel): "vocab_size": 30522, "pad_token_id": 0, }, + "ppminilm-6l-768h": { + "attention_probs_dropout_prob": 0.1, + "intermediate_size": 3072, + "hidden_act": "relu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 6, + "type_vocab_size": 4, + "vocab_size": 21128, + "pad_token_id": 0, + }, } resource_files_names = {"model_state": "model_state.pdparams"} pretrained_resource_files_map = { @@ -182,6 +196,8 @@ class ErniePretrainedModel(PretrainedModel): "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/ernie_v2_eng_base_finetuned_squad.pdparams", "ernie-2.0-large-en": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_large/ernie_v2_eng_large.pdparams", + "ppminilm-6l-768h": + "https://bj.bcebos.com/paddlenlp/models/transformers/ppminilm-6l-768h/ppminilm-6l-768h.pdparams", } } base_model_prefix = "ernie" diff --git a/paddlenlp/transformers/ernie/tokenizer.py b/paddlenlp/transformers/ernie/tokenizer.py index 589c64d26881..13686910b702 100644 --- a/paddlenlp/transformers/ernie/tokenizer.py +++ b/paddlenlp/transformers/ernie/tokenizer.py @@ -93,6 +93,8 @@ class ErnieTokenizer(PretrainedTokenizer): "https://bj.bcebos.com/paddlenlp/models/transformers/ernie-gen-large/vocab.txt", "ernie-gen-large-430g-en": "https://bj.bcebos.com/paddlenlp/models/transformers/ernie-gen-large-430g/vocab.txt", + "ppminilm-6l-768h": + "https://bj.bcebos.com/paddlenlp/models/transformers/ppminilm-6l-768h/vocab.txt", } } pretrained_init_configuration = { @@ -120,6 +122,9 @@ class ErnieTokenizer(PretrainedTokenizer): "ernie-gen-large-430g-en": { "do_lower_case": True }, + "ppminilm-6l-768h": { + "do_lower_case": True + }, } def __init__(self,