diff --git a/community/junnyu/ckiplab-bert-base-chinese-ner/README.md b/community/junnyu/ckiplab-bert-base-chinese-ner/README.md index 49f895f67da3..12b2bea08893 100644 --- a/community/junnyu/ckiplab-bert-base-chinese-ner/README.md +++ b/community/junnyu/ckiplab-bert-base-chinese-ner/README.md @@ -4,6 +4,7 @@ 关于完整使用方法及其他信息,请参考 https://github.com/ckiplab/ckip-transformers 。 **模型结构**: **`BertForTokenClassification`**,带有token分类头的Bert模型。 + **适用下游任务**:**命名实体识别**,该权重已经在下游`NER`任务上进行了微调,因此可直接使用。 # 使用示例 diff --git a/community/junnyu/ckiplab-bert-base-chinese-pos/README.md b/community/junnyu/ckiplab-bert-base-chinese-pos/README.md index 582b3024f43c..62cf10495779 100644 --- a/community/junnyu/ckiplab-bert-base-chinese-pos/README.md +++ b/community/junnyu/ckiplab-bert-base-chinese-pos/README.md @@ -4,6 +4,7 @@ 关于完整使用方法及其他信息,请参考 https://github.com/ckiplab/ckip-transformers 。 **模型结构**: **`BertForTokenClassification`**,带有token分类头的Bert模型。 + **适用下游任务**:**词性标注**,该权重已经在下游`POS`任务上进行了微调,因此可直接使用。 # 使用示例 diff --git a/community/junnyu/ckiplab-bert-base-chinese-ws/README.md b/community/junnyu/ckiplab-bert-base-chinese-ws/README.md index 99061c156617..475d3494845d 100644 --- a/community/junnyu/ckiplab-bert-base-chinese-ws/README.md +++ b/community/junnyu/ckiplab-bert-base-chinese-ws/README.md @@ -4,6 +4,7 @@ 关于完整使用方法及其他信息,请参考 https://github.com/ckiplab/ckip-transformers 。 **模型结构**: **`BertForTokenClassification`**,带有token分类头的Bert模型。 + **适用下游任务**:**分词**,该权重已经在下游`WS`任务上进行了微调,因此可直接使用。 # 使用示例 diff --git a/community/junnyu/electra_compare.py b/community/junnyu/electra_compare.py new file mode 100644 index 000000000000..76e93ca6aee6 --- /dev/null +++ b/community/junnyu/electra_compare.py @@ -0,0 +1,138 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import torch +import numpy as np +import paddlenlp.transformers as ppnlp +import transformers as hgnlp + + +def compare(a, b): + a = a.cpu().numpy() + b = b.cpu().numpy() + meandif = np.abs(a - b).mean() + maxdif = np.abs(a - b).max() + print("mean dif:", meandif) + print("max dif:", maxdif) + + +def compare_discriminator( + path="junnyu/hfl-chinese-electra-180g-base-discriminator"): + pdmodel = ppnlp.ElectraDiscriminator.from_pretrained(path) + ptmodel = hgnlp.ElectraForPreTraining.from_pretrained(path).cuda() + tokenizer = ppnlp.ElectraTokenizer.from_pretrained(path) + pdmodel.eval() + ptmodel.eval() + text = "欢迎使用paddlenlp!"
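+    # Tokenize the text once, then build framework-specific inputs below: int64 Paddle tensors for pdmodel and long CUDA torch tensors for ptmodel, each with an added batch dimension.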
+ pdinputs = { + k: paddle.to_tensor( + v, dtype="int64").unsqueeze(0) + for k, v in tokenizer(text).items() + } + ptinputs = { + k: torch.tensor( + v, dtype=torch.long).unsqueeze(0).cuda() + for k, v in tokenizer(text).items() + } + with paddle.no_grad(): + pd_logits = pdmodel(**pdinputs) + + with torch.no_grad(): + pt_logits = ptmodel(**ptinputs).logits + + compare(pd_logits, pt_logits) + + +def compare_generator(): + text = "本院经审查认为,本案[MASK]民间借贷纠纷申请再审案件,应重点审查二审判决是否存在错误的情形。" + # ppnlp + path = "junnyu/hfl-chinese-legal-electra-small-generator" + model = ppnlp.ElectraForMaskedLM.from_pretrained(path) + tokenizer = ppnlp.ElectraTokenizer.from_pretrained(path) + model.eval() + tokens = ["[CLS]"] + text_list = text.split("[MASK]") + for i, t in enumerate(text_list): + tokens.extend(tokenizer.tokenize(t)) + if i == len(text_list) - 1: + tokens.extend(["[SEP]"]) + else: + tokens.extend(["[MASK]"]) + + input_ids_list = tokenizer.convert_tokens_to_ids(tokens) + input_ids = paddle.to_tensor([input_ids_list]) + with paddle.no_grad(): + pd_outputs = model(input_ids)[0] + pd_outputs_sentence = "paddle: " + for i, id in enumerate(input_ids_list): + if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]: + scores, index = paddle.nn.functional.softmax(pd_outputs[i], + -1).topk(5) + tokens = tokenizer.convert_ids_to_tokens(index.tolist()) + outputs = [] + for score, tk in zip(scores.tolist(), tokens): + outputs.append(f"{tk}={score}") + pd_outputs_sentence += "[" + "||".join(outputs) + "]" + " " + else: + pd_outputs_sentence += "".join( + tokenizer.convert_ids_to_tokens( + [id], skip_special_tokens=True)) + " " + + print(pd_outputs_sentence) + + # transformers + path = "hfl/chinese-legal-electra-small-generator" + config = hgnlp.ElectraConfig.from_pretrained(path) + config.hidden_size = 64 + config.intermediate_size = 256 + config.num_attention_heads = 1 + model = hgnlp.ElectraForMaskedLM.from_pretrained(path, config=config) + tokenizer = hgnlp.ElectraTokenizer.from_pretrained(path) + model.eval() + + inputs = tokenizer(text, return_tensors="pt") + + with torch.no_grad(): + pt_outputs = model(**inputs).logits[0] + pt_outputs_sentence = "pytorch: " + for i, id in enumerate(inputs["input_ids"][0].tolist()): + if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]: + scores, index = torch.nn.functional.softmax(pt_outputs[i], + -1).topk(5) + tokens = tokenizer.convert_ids_to_tokens(index.tolist()) + outputs = [] + for score, tk in zip(scores.tolist(), tokens): + outputs.append(f"{tk}={score}") + pt_outputs_sentence += "[" + "||".join(outputs) + "]" + " " + else: + pt_outputs_sentence += "".join( + tokenizer.convert_ids_to_tokens( + [id], skip_special_tokens=True)) + " " + + print(pt_outputs_sentence) + + +if __name__ == "__main__": + compare_discriminator( + path="junnyu/hfl-chinese-electra-180g-base-discriminator") + # # mean dif: 3.1698835e-06 + # # max dif: 1.335144e-05 + compare_discriminator( + path="junnyu/hfl-chinese-electra-180g-small-ex-discriminator") + # mean dif: 3.7930229e-06 + # max dif: 1.04904175e-05 + compare_generator() + # paddle: 本 院 经 审 查 认 为 , 本 案 [因=0.27444931864738464||经=0.18613006174564362||系=0.09408623725175858||的=0.07536833733320236||就=0.033634234219789505] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。 + # pytorch: 本 院 经 审 查 认 为 , 本 案 [因=0.2744344472885132||经=0.1861187219619751||系=0.09407979995012283||的=0.07537488639354706||就=0.03363779932260513] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。 diff --git 
a/community/junnyu/electra_convert_huggingface2paddle.py b/community/junnyu/electra_convert_huggingface2paddle.py new file mode 100644 index 000000000000..cbccca4917e2 --- /dev/null +++ b/community/junnyu/electra_convert_huggingface2paddle.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict +import argparse + +huggingface_to_paddle = { + "embeddings.LayerNorm": "embeddings.layer_norm", + "encoder.layer": "encoder.layers", + "attention.self.query.": "self_attn.q_proj.", + "attention.self.key.": "self_attn.k_proj.", + "attention.self.value.": "self_attn.v_proj.", + "attention.output.dense.": "self_attn.out_proj.", + "intermediate.dense": "linear1", + "output.dense": "linear2", + "attention.output.LayerNorm": "norm1", + "output.LayerNorm": "norm2", + "generator_predictions.LayerNorm": "generator_predictions.layer_norm", + "generator_lm_head.bias": "generator_lm_head_bias", +} + +skip_weights = ["electra.embeddings.position_ids"] +dont_transpose = ["_embeddings.weight", "LayerNorm."] + + +def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path, + paddle_dump_path): + import torch + import paddle + pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu") + paddle_state_dict = OrderedDict() + for k, v in pytorch_state_dict.items(): + if k == "generator_lm_head.weight": continue + is_transpose = False + if k in skip_weights: + continue + if k[-7:] == ".weight": + if not any([w in k for w in dont_transpose]): + if v.ndim == 2: + v = v.transpose(0, 1) + is_transpose = True + oldk = k + for huggingface_name, paddle_name in huggingface_to_paddle.items(): + k = k.replace(huggingface_name, paddle_name) + + print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}") + paddle_state_dict[k] = v.data.numpy() + + paddle.save(paddle_state_dict, paddle_dump_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--pytorch_checkpoint_path", + default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\pytorch_model.bin", + type=str, + required=False, + help="Path to the Pytorch checkpoint path.") + parser.add_argument( + "--paddle_dump_path", + default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\model_state.pdparams", + type=str, + required=False, + help="Path to the output Paddle model.") + args = parser.parse_args() + convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path, + args.paddle_dump_path) diff --git a/community/junnyu/hfl-chinese-electra-180g-base-discriminator/README.md b/community/junnyu/hfl-chinese-electra-180g-base-discriminator/README.md new file mode 100644 index 000000000000..12b229858b0c --- /dev/null +++ b/community/junnyu/hfl-chinese-electra-180g-base-discriminator/README.md @@ -0,0 +1,37 @@ +# 详细介绍 +**介绍**:该模型是base版本的Electra discriminator模型,并且在180G的中文数据上进行训练。 + +**模型结构**: **`ElectraDiscriminator`**,带有判别器的中文Electra模型。 + +**适用下游任务**:**通用下游任务**,如:句子级别分类,token级别分类,抽取式问答等任务。 + +# 
使用示例 + +```python +import paddle +from paddlenlp.transformers import ElectraDiscriminator, ElectraTokenizer + +path = "junnyu/hfl-chinese-electra-180g-base-discriminator" +model = ElectraDiscriminator.from_pretrained(path) +tokenizer = ElectraTokenizer.from_pretrained(path) +model.eval() + +text = "欢迎使用paddlenlp!" +inputs = { + k: paddle.to_tensor( + v, dtype="int64").unsqueeze(0) + for k, v in tokenizer(text).items() +} + +with paddle.no_grad(): + logits = model(**inputs) + +print(logits.shape) + +``` + +# 权重来源 + +https://huggingface.co/hfl/chinese-electra-180g-base-discriminator +谷歌和斯坦福大学发布了一种名为 ELECTRA 的新预训练模型,与 BERT 及其变体相比,该模型具有非常紧凑的模型尺寸和相对具有竞争力的性能。 为进一步加快中文预训练模型的研究,HIT与科大讯飞联合实验室(HFL)发布了基于ELECTRA官方代码的中文ELECTRA模型。 与 BERT 及其变体相比,ELECTRA-small 只需 1/10 的参数就可以在几个 NLP 任务上达到相似甚至更高的分数。 +这个项目依赖于官方ELECTRA代码: https://github.com/google-research/electra diff --git a/community/junnyu/hfl-chinese-electra-180g-base-discriminator/files.json b/community/junnyu/hfl-chinese-electra-180g-base-discriminator/files.json new file mode 100644 index 000000000000..1f0b002a2635 --- /dev/null +++ b/community/junnyu/hfl-chinese-electra-180g-base-discriminator/files.json @@ -0,0 +1,6 @@ +{ + "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_config.json", + "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_state.pdparams", + "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/tokenizer_config.json", + "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/vocab.txt" +} diff --git a/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/README.md b/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/README.md new file mode 100644 index 000000000000..633b0cf57ce2 --- /dev/null +++ b/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/README.md @@ -0,0 +1,36 @@ +# 详细介绍 +**介绍**:该模型是small版本的Electra discriminator模型,并且在180G的中文数据上进行训练。 + +**模型结构**: **`ElectraDiscriminator`**,带有判别器的中文Electra模型。 + +**适用下游任务**:**通用下游任务**,如:句子级别分类,token级别分类,抽取式问答等任务。 + +# 使用示例 + +```python +import paddle +from paddlenlp.transformers import ElectraDiscriminator,ElectraTokenizer + +path = "junnyu/hfl-chinese-electra-180g-small-ex-discriminator" +model = ElectraDiscriminator.from_pretrained(path) +tokenizer = ElectraTokenizer.from_pretrained(path) +model.eval() + +text = "欢迎使用paddlenlp!" 
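+# tokenizer(text) returns plain python lists; wrap each one in an int64 tensor and add a batch dimension before calling the model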
+inputs = { + k: paddle.to_tensor( + v, dtype="int64").unsqueeze(0) + for k, v in tokenizer(text).items() +} + +with paddle.no_grad(): + logits = model(**inputs) + +print(logits.shape) + +``` + +# 权重来源 + +https://huggingface.co/hfl/chinese-electra-180g-small-ex-discriminator +谷歌和斯坦福大学发布了一种名为 ELECTRA 的新预训练模型,与 BERT 及其变体相比,该模型具有非常紧凑的模型尺寸和相对具有竞争力的性能。 为进一步加快中文预训练模型的研究,HIT与科大讯飞联合实验室(HFL)发布了基于ELECTRA官方代码的中文ELECTRA模型。 与 BERT 及其变体相比,ELECTRA-small 只需 1/10 的参数就可以在几个 NLP 任务上达到相似甚至更高的分数。 diff --git a/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/files.json b/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/files.json new file mode 100644 index 000000000000..c0e2a2eeba99 --- /dev/null +++ b/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/files.json @@ -0,0 +1,6 @@ +{ + "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_config.json", + "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_state.pdparams", + "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/tokenizer_config.json", + "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/vocab.txt" +} diff --git a/community/junnyu/hfl-chinese-legal-electra-small-generator/README.md b/community/junnyu/hfl-chinese-legal-electra-small-generator/README.md new file mode 100644 index 000000000000..535b10c4ccc4 --- /dev/null +++ b/community/junnyu/hfl-chinese-legal-electra-small-generator/README.md @@ -0,0 +1,58 @@ +# 详细介绍 +**介绍**:该模型是small版本的Electra generator模型,该模型在法律领域数据上进行了预训练。 + +**模型结构**: **`ElectraGenerator`**,带有生成器的中文Electra模型。 + +**适用下游任务**:**法律领域的下游任务**,如:法律领域的句子级别分类,法律领域的token级别分类,法律领域的抽取式问答等任务。 +(注:生成器的效果不好,通常使用判别器进行下游任务微调) + + +# 使用示例 + +```python +import paddle +from paddlenlp.transformers import ElectraGenerator, ElectraTokenizer + +text = "本院经审查认为,本案[MASK]民间借贷纠纷申请再审案件,应重点审查二审判决是否存在错误的情形。" +path = "junnyu/hfl-chinese-legal-electra-small-generator" +model = ElectraGenerator.from_pretrained(path) +model.eval() +tokenizer = ElectraTokenizer.from_pretrained(path) + +tokens = ["[CLS]"] +text_list = text.split("[MASK]") +for i, t in enumerate(text_list): + tokens.extend(tokenizer.tokenize(t)) + if i == len(text_list) - 1: + tokens.extend(["[SEP]"]) + else: + tokens.extend(["[MASK]"]) + +input_ids_list = tokenizer.convert_tokens_to_ids(tokens) +input_ids = paddle.to_tensor([input_ids_list]) +with paddle.no_grad(): + pd_outputs = model(input_ids)[0] +pd_outputs_sentence = "paddle: " +for i, id in enumerate(input_ids_list): + if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]: + scores, index = paddle.nn.functional.softmax(pd_outputs[i], + -1).topk(5) + tokens = tokenizer.convert_ids_to_tokens(index.tolist()) + outputs = [] + for score, tk in zip(scores.tolist(), tokens): + outputs.append(f"{tk}={score}") + pd_outputs_sentence += "[" + "||".join(outputs) + "]" + " " + else: + pd_outputs_sentence += "".join( + tokenizer.convert_ids_to_tokens( + [id], skip_special_tokens=True)) + " " + +print(pd_outputs_sentence) +# paddle: 本 院 经 审 查 认 为 , 本 案 [因=0.27444931864738464||经=0.18613006174564362||系=0.09408623725175858||的=0.07536833733320236||就=0.033634234219789505] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。 +``` + +# 权重来源 + 
+https://huggingface.co/hfl/chinese-legal-electra-small-generator +谷歌和斯坦福大学发布了一种名为 ELECTRA 的新预训练模型,与 BERT 及其变体相比,该模型具有非常紧凑的模型尺寸和相对具有竞争力的性能。 为进一步加快中文预训练模型的研究,HIT与科大讯飞联合实验室(HFL)发布了基于ELECTRA官方代码的中文ELECTRA模型。 与 BERT 及其变体相比,ELECTRA-small 只需 1/10 的参数就可以在几个 NLP 任务上达到相似甚至更高的分数。 +这个项目依赖于官方ELECTRA代码: https://github.com/google-research/electra diff --git a/community/junnyu/hfl-chinese-legal-electra-small-generator/files.json b/community/junnyu/hfl-chinese-legal-electra-small-generator/files.json new file mode 100644 index 000000000000..4e94c9591d9f --- /dev/null +++ b/community/junnyu/hfl-chinese-legal-electra-small-generator/files.json @@ -0,0 +1,6 @@ +{ + "model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_config.json", + "model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_state.pdparams", + "tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/tokenizer_config.json", + "vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/vocab.txt" +} diff --git a/community/junnyu/nlptown-bert-base-multilingual-uncased-sentiment/README.md b/community/junnyu/nlptown-bert-base-multilingual-uncased-sentiment/README.md index 48207bf2b348..174672cbdd7e 100644 --- a/community/junnyu/nlptown-bert-base-multilingual-uncased-sentiment/README.md +++ b/community/junnyu/nlptown-bert-base-multilingual-uncased-sentiment/README.md @@ -2,6 +2,7 @@ **介绍**:nlptown-bert-base-multilingual-uncased-sentiment是一个带有序列分类头的多语言BERT模型,该模型可用于对英语、荷兰语、德语、法语、西班牙语和意大利语这六种语言的商品评论进行情感分析。其中评论的情感标签为1-5之间的星级。 **模型结构**: **`BertForSequenceClassification`**,带有序列分类头的Bert模型。 + **适用下游任务**:**情感分类**,该权重已经在下游`Sentiment classification`任务上进行了微调,因此可直接使用。 ## 训练数据 @@ -19,14 +20,15 @@ 微调后的模型在每种语言的 5,000 条商品评论中获得了以下准确率: - Accuracy (exact) 完全匹配。 - Accuracy (off-by-1) 是模型预测的评分等级与人工给出的评分等级差值小于等于 1 所占的百分比。 + | Language | Accuracy (exact) | Accuracy (off-by-1) | -| -------- | ---------------------- | ------------------- | -| English | 67% | 95% -| Dutch | 57% | 93% -| German | 61% | 94% -| French | 59% | 94% -| Italian | 59% | 95% -| Spanish | 58% | 95% +| -------- | ---------------- | ------------------- | +| English | 67% | 95% | +| Dutch | 57% | 93% | +| German | 61% | 94% | +| French | 59% | 94% | +| Italian | 59% | 95% | +| Spanish | 58% | 95% | ## 联系方式 对于类似模型的问题、反馈和/或请求,请联系 [NLP Town](https://www.nlp.town)。 diff --git a/community/junnyu/tbs17-MathBERT/README.md b/community/junnyu/tbs17-MathBERT/README.md index 5114a943c899..b1452209a7a6 100644 --- a/community/junnyu/tbs17-MathBERT/README.md +++ b/community/junnyu/tbs17-MathBERT/README.md @@ -6,6 +6,7 @@ **模型结构**: **`BertForPretraining`**,带有`MLM`和`NSP`任务的Bert模型。 + **适用下游任务**:**数学领域相关的任务**,如:与数学领域相关的`句子级别分类`,`token级别分类`,`问答`等。 ## 训练数据 diff --git a/docs/model_zoo/transformers.rst b/docs/model_zoo/transformers.rst index dfeaed75a4fe..329e5118e9ea 100644 --- a/docs/model_zoo/transformers.rst +++ b/docs/model_zoo/transformers.rst @@ -209,6 +209,18 @@ Transformer预训练模型汇总 | |``chinese-electra-base`` | Chinese | 12-layer, 768-hidden, | | | | | 12-heads, _M parameters. | | | | | Trained on Chinese text. 
| +| +----------------------------------------------------------------------------------+--------------+-----------------------------------------+ +| |``junnyu/hfl-chinese-electra-180g-base-discriminator`` | Chinese | Discriminator, 12-layer, 768-hidden, | +| | | | 12-heads, 102M parameters. | +| | | | Trained on 180g Chinese text. | +| +----------------------------------------------------------------------------------+--------------+-----------------------------------------+ +| |``junnyu/hfl-chinese-electra-180g-small-ex-discriminator`` | Chinese | Discriminator, 24-layer, 256-hidden, | +| | | | 4-heads, 24M parameters. | +| | | | Trained on 180g Chinese text. | +| +----------------------------------------------------------------------------------+--------------+-----------------------------------------+ +| |``junnyu/hfl-chinese-legal-electra-small-generator`` | Chinese | Generator, 12-layer, 64-hidden, | +| | | | 1-heads, 3M parameters. | +| | | | Trained on Chinese legal corpus. | +--------------------+----------------------------------------------------------------------------------+--------------+-----------------------------------------+ |ERNIE_ |``ernie-1.0`` | Chinese | 12-layer, 768-hidden, | | | | | 12-heads, 108M parameters. | @@ -451,7 +463,7 @@ Transformer预训练模型适用任务汇总 +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ |DistilBert_ | ✅ | ✅ | ✅ | ❌ | ❌ | +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ -|ELECTRA_ | ✅ | ✅ | ❌ | ❌ | ❌ | +|ELECTRA_ | ✅ | ✅ | ❌ | ❌ | ✅ | +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ |ERNIE_ | ✅ | ✅ | ✅ | ❌ | ❌ | +--------------------+-------------------------+----------------------+--------------------+-----------------+-----------------+ diff --git a/paddlenlp/transformers/electra/README.md b/paddlenlp/transformers/electra/README.md deleted file mode 100644 index 39e5600211aa..000000000000 --- a/paddlenlp/transformers/electra/README.md +++ /dev/null @@ -1 +0,0 @@ -# ELECTRA diff --git a/paddlenlp/transformers/electra/modeling.py b/paddlenlp/transformers/electra/modeling.py index 7d52756ddfdc..91bb3dd3c5eb 100644 --- a/paddlenlp/transformers/electra/modeling.py +++ b/paddlenlp/transformers/electra/modeling.py @@ -11,23 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -import time -from typing import Optional, Tuple -from collections import OrderedDict import paddle import paddle.nn as nn -import paddle.tensor as tensor import paddle.nn.functional as F from .. 
import PretrainedModel, register_base_model __all__ = [ - 'ElectraModel', 'ElectraPretrainedModel', 'ElectraForTotalPretraining', - 'ElectraDiscriminator', 'ElectraGenerator', 'ElectraClassificationHead', - 'ElectraForSequenceClassification', 'ElectraForTokenClassification', - 'ElectraPretrainingCriterion' + 'ElectraModel', + 'ElectraPretrainedModel', + 'ElectraForTotalPretraining', + 'ElectraDiscriminator', + 'ElectraGenerator', + 'ElectraClassificationHead', + 'ElectraForSequenceClassification', + 'ElectraForTokenClassification', + 'ElectraPretrainingCriterion', + 'ElectraForMultipleChoice', + 'ElectraForQuestionAnswering', + 'ElectraForMaskedLM', ] @@ -232,15 +235,15 @@ class ElectraPretrainedModel(PretrainedModel): pretrained_resource_files_map = { "model_state": { "electra-small": - "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-small.pdparams", + "https://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-small.pdparams", "electra-base": - "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-base.pdparams", + "https://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-base.pdparams", "electra-large": - "http://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-large.pdparams", + "https://paddlenlp.bj.bcebos.com/models/transformers/electra/electra-large.pdparams", "chinese-electra-small": - "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-small/chinese-electra-small.pdparams", + "https://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-small/chinese-electra-small.pdparams", "chinese-electra-base": - "http://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-base/chinese-electra-base.pdparams", + "https://paddlenlp.bj.bcebos.com/models/transformers/chinese-electra-base/chinese-electra-base.pdparams", } } @@ -321,8 +324,10 @@ class ElectraModel(ElectraPretrainedModel): vocab_size (int): Vocabulary size of `inputs_ids` in `ElectraModel`. Also is the vocab size of token embedding matrix. Defines the number of different tokens that can be represented by the `inputs_ids` passed when calling `ElectraModel`. + embedding_size (int, optional): + Dimensionality of the embedding layer. hidden_size (int, optional): - Dimensionality of the embedding layer, encoder layer and pooler layer. + Dimensionality of the encoder layer and pooler layer. num_hidden_layers (int, optional): Number of hidden layers in the Transformer encoder. num_attention_heads (int, optional): @@ -350,7 +355,7 @@ class ElectraModel(ElectraPretrainedModel): .. note:: A normal_initializer initializes weight matrices as normal distributions. - See :meth:`BertPretrainedModel.init_weights()` for how weights are initialized in `ElectraModel`. + See :meth:`ElectraPretrainedModel.init_weights()` for how weights are initialized in `ElectraModel`. pad_token_id (int, optional): The index of padding token in the token vocabulary. @@ -443,12 +448,16 @@ def forward(self, inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} output = model(**inputs) + ''' if attention_mask is None: attention_mask = paddle.unsqueeze( (input_ids == self.pad_token_id ).astype(paddle.get_default_dtype()) * -1e9, axis=[1, 2]) + else: + if attention_mask.ndim == 2: + attention_mask = attention_mask.unsqueeze(axis=[1, 2]) embedding_output = self.embeddings( input_ids=input_ids, @@ -492,12 +501,12 @@ def forward(self, Args: input_ids (Tensor): See :class:`ElectraModel`. 
+ token_type_ids (Tensor, optional): + See :class:`ElectraModel`. position_ids (Tensor, optional): See :class:`ElectraModel`. attention_mask (Tensor, optional): See :class:`ElectraModel`. - use_cache (bool, optional): - See :class:`ElectraModel`. Returns: Tensor: Returns tensor `logits`, the prediction result of replaced tokens. @@ -515,8 +524,8 @@ def forward(self, inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} - output = model(**inputs) - logits = output[0] + logits = model(**inputs) + """ discriminator_sequence_output = self.electra( input_ids, token_type_ids, position_ids, attention_mask) @@ -550,7 +559,7 @@ def __init__(self, electra): self.electra.config["embedding_size"], self.electra.config["vocab_size"]) else: - self.generator_lm_head_bias = paddle.fluid.layers.create_parameter( + self.generator_lm_head_bias = self.create_parameter( shape=[self.electra.config["vocab_size"]], dtype=paddle.get_default_dtype(), is_bias=True) @@ -569,12 +578,12 @@ def forward(self, Args: input_ids (Tensor): See :class:`ElectraModel`. + token_type_ids (Tensor, optional): + See :class:`ElectraModel`. position_ids (Tensor, optional): See :class:`ElectraModel`. attention_mask (Tensor, optional): See :class:`ElectraModel`. - use_cache (bool, optional): - See :class:`ElectraModel`. Returns: Tensor: Returns tensor `prediction_scores`, the scores of Electra Generator. @@ -591,8 +600,8 @@ def forward(self, inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} - output = model(**inputs) - prediction_scores = output[0] + prediction_scores = model(**inputs) + """ generator_sequence_output = self.electra(input_ids, token_type_ids, position_ids, attention_mask) @@ -665,17 +674,21 @@ class ElectraForSequenceClassification(ElectraPretrainedModel): An instance of ElectraModel. num_classes (int, optional): The number of classes. Defaults to `2`. - + dropout (float, optional): + The dropout probability for output of Electra. + If None, use the same value as `hidden_dropout_prob` of `ElectraModel` + instance `electra`. Defaults to None. """ - def __init__(self, electra, num_classes): + def __init__(self, electra, num_classes=2, dropout=None): super(ElectraForSequenceClassification, self).__init__() self.num_classes = num_classes self.electra = electra self.classifier = ElectraClassificationHead( - self.electra.config["hidden_size"], - self.electra.config["hidden_dropout_prob"], self.num_classes) - + hidden_size=self.electra.config["hidden_size"], + hidden_dropout_prob=dropout if dropout is not None else + self.electra.config["hidden_dropout_prob"], + num_classes=self.num_classes, ) self.init_weights() def forward(self, @@ -712,9 +725,8 @@ def forward(self, inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} - outputs = model(**inputs) + logits = model(**inputs) - logits = outputs[0] """ sequence_output = self.electra(input_ids, token_type_ids, position_ids, attention_mask) @@ -734,14 +746,18 @@ class ElectraForTokenClassification(ElectraPretrainedModel): An instance of ElectraModel. num_classes (int, optional): The number of classes. Defaults to `2`. - + dropout (float, optional): + The dropout probability for output of Electra. + If None, use the same value as `hidden_dropout_prob` of `ElectraModel` + instance `electra`. Defaults to None. 
""" - def __init__(self, electra, num_classes): + def __init__(self, electra, num_classes=2, dropout=None): super(ElectraForTokenClassification, self).__init__() self.num_classes = num_classes self.electra = electra - self.dropout = nn.Dropout(self.electra.config["hidden_dropout_prob"]) + self.dropout = nn.Dropout(dropout if dropout is not None else + self.electra.config["hidden_dropout_prob"]) self.classifier = nn.Linear(self.electra.config["hidden_size"], self.num_classes) self.init_weights() @@ -780,9 +796,8 @@ def forward(self, inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} - outputs = model(**inputs) + logits = model(**inputs) - logits = outputs[0] """ sequence_output = self.electra(input_ids, token_type_ids, position_ids, attention_mask) @@ -983,7 +998,7 @@ def forward(self, position_ids(Tensor, optional): See :class:`ElectraModel`. attention_mask (list, optional): - See :class:`AlbertModel`. + See :class:`ElectraModel`. raw_input_ids(Tensor, optional): Raw inputs used to get discriminator labels. Its data type should be `int64` and it has a shape of [batch_size, sequence_length]. @@ -1011,7 +1026,7 @@ def forward(self, and its shape is [batch_size, sequence_length]. - `attention_mask` (Tensor): - See :class:`AlbertModel`. Its data type should be bool. + See :class:`ElectraModel`. Its data type should be bool. """ @@ -1038,6 +1053,146 @@ def forward(self, return gen_logits, disc_logits, disc_labels, attention_mask +class ElectraPooler(nn.Layer): + def __init__(self, hidden_size, pool_act="gelu"): + super(ElectraPooler, self).__init__() + self.dense = nn.Linear(hidden_size, hidden_size) + self.activation = get_activation(pool_act) + self.pool_act = pool_act + + def forward(self, hidden_states): + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class ElectraForMultipleChoice(ElectraPretrainedModel): + """ + Electra Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a + softmax) e.g. for RocStories/SWAG tasks. + + Args: + electra (:class:`ElectraModel`): + An instance of ElectraModel. + num_choices (int, optional): + The number of choices. Defaults to `2`. + dropout (float, optional): + The dropout probability for output of Electra. + If None, use the same value as `hidden_dropout_prob` of `ElectraModel` + instance `electra`. Defaults to None. + """ + + def __init__(self, electra, num_choices=2, dropout=None): + super(ElectraForMultipleChoice, self).__init__() + self.num_choices = num_choices + self.electra = electra + self.sequence_summary = ElectraPooler( + self.electra.config["hidden_size"], pool_act="gelu") + self.dropout = nn.Dropout(dropout if dropout is not None else + self.electra.config["hidden_dropout_prob"]) + self.classifier = nn.Linear(self.electra.config["hidden_size"], 1) + self.init_weights() + + def forward(self, + input_ids=None, + token_type_ids=None, + position_ids=None, + attention_mask=None): + r""" + The ElectraForMultipleChoice forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`ElectraModel` and shape as [batch_size, num_choice, sequence_length]. + token_type_ids (Tensor, optional): + See :class:`ElectraModel` and shape as [batch_size, num_choice, sequence_length]. 
+ position_ids(Tensor, optional): + See :class:`ElectraModel` and shape as [batch_size, num_choice, sequence_length]. + attention_mask (list, optional): + See :class:`ElectraModel` and shape as [batch_size, num_choice, sequence_length]. + + Returns: + Tensor: Returns tensor `reshaped_logits`, a tensor of the multiple choice classification logits. + Shape as `[batch_size, num_choice]` and dtype as `float32`. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import ElectraForMultipleChoice, ElectraTokenizer + from paddlenlp.data import Pad, Dict + + tokenizer = ElectraTokenizer.from_pretrained('electra-small') + model = ElectraForMultipleChoice.from_pretrained('electra-small', num_choices=2) + + data = [ + { + "question": "how do you turn on an ipad screen?", + "answer1": "press the volume button.", + "answer2": "press the lock button.", + "label": 1, + }, + { + "question": "how do you indent something?", + "answer1": "leave a space before starting the writing", + "answer2": "press the spacebar", + "label": 0, + }, + ] + + text = [] + text_pair = [] + for d in data: + text.append(d["question"]) + text_pair.append(d["answer1"]) + text.append(d["question"]) + text_pair.append(d["answer2"]) + + inputs = tokenizer(text, text_pair) + batchify_fn = lambda samples, fn=Dict( + { + "input_ids": Pad(axis=0, pad_val=tokenizer.pad_token_id), # input_ids + "token_type_ids": Pad( + axis=0, pad_val=tokenizer.pad_token_type_id + ), # token_type_ids + } + ): fn(samples) + inputs = batchify_fn(inputs) + + reshaped_logits = model( + input_ids=paddle.to_tensor(inputs[0], dtype="int64"), + token_type_ids=paddle.to_tensor(inputs[1], dtype="int64"), + ) + print(reshaped_logits.shape) + # [2, 2] + + """ + input_ids = input_ids.reshape( + (-1, input_ids.shape[-1])) # flat_input_ids: [bs*num_choice,seq_l] + + if token_type_ids is not None: + token_type_ids = token_type_ids.reshape( + (-1, token_type_ids.shape[-1])) + if position_ids is not None: + position_ids = position_ids.reshape((-1, position_ids.shape[-1])) + if attention_mask is not None: + attention_mask = attention_mask.reshape( + (-1, attention_mask.shape[-1])) + + sequence_output = self.electra(input_ids, token_type_ids, position_ids, + attention_mask) + pooled_output = self.sequence_summary(sequence_output) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) # logits: (bs*num_choice,1) + reshaped_logits = logits.reshape( + (-1, self.num_choices)) # logits: (bs, num_choice) + + return reshaped_logits + + class ElectraPretrainingCriterion(paddle.nn.Layer): ''' @@ -1124,3 +1279,84 @@ def forward(self, generator_prediction_scores, disc_loss = disc_loss.sum() / total_positions.sum() return self.gen_weight * gen_loss + self.disc_weight * disc_loss + + +class ElectraForQuestionAnswering(ElectraPretrainedModel): + """ + Electra Model with a span classification head on top for extractive question-answering tasks like + SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and + `span end logits`). + + Args: + electra (:class:`ElectraModel`): + An instance of ElectraModel. 
+ + """ + + def __init__(self, electra): + super(ElectraForQuestionAnswering, self).__init__() + self.electra = electra + self.classifier = nn.Linear(self.electra.config["hidden_size"], 2) + self.init_weights() + + def forward(self, + input_ids, + token_type_ids=None, + position_ids=None, + attention_mask=None): + r""" + The ElectraForQuestionAnswering forward method, overrides the __call__() special method. + + Args: + input_ids (Tensor): + See :class:`ElectraModel`. + token_type_ids (Tensor, optional): + See :class:`ElectraModel`. + position_ids(Tensor, optional): + See :class:`ElectraModel`. + attention_mask (list, optional): + See :class:`ElectraModel`. + Returns: + tuple: Returns tuple (`start_logits`, `end_logits`). + + With the fields: + + - `start_logits` (Tensor): + A tensor of the input token classification logits, indicates the start position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + - `end_logits` (Tensor): + A tensor of the input token classification logits, indicates the end position of the labelled span. + Its data type should be float32 and its shape is [batch_size, sequence_length]. + + Example: + .. code-block:: + + import paddle + from paddlenlp.transformers import ElectraForQuestionAnswering, ElectraTokenizer + + tokenizer = ElectraTokenizer.from_pretrained('electra-small') + model = ElectraForQuestionAnswering.from_pretrained('electra-small') + + inputs = tokenizer("Welcome to use PaddlePaddle and PaddleNLP!") + inputs = {k:paddle.to_tensor([v]) for (k, v) in inputs.items()} + outputs = model(**inputs) + + start_logits = outputs[0] + end_logits = outputs[1] + + """ + sequence_output = self.electra( + input_ids, + token_type_ids, + position_ids=position_ids, + attention_mask=attention_mask) + logits = self.classifier(sequence_output) + logits = paddle.transpose(logits, perm=[2, 0, 1]) + start_logits, end_logits = paddle.unstack(x=logits, axis=0) + + return start_logits, end_logits + + +# ElectraForMaskedLM is the same as ElectraGenerator +ElectraForMaskedLM = ElectraGenerator diff --git a/paddlenlp/transformers/electra/tokenizer.py b/paddlenlp/transformers/electra/tokenizer.py index 5e70d24588b6..c4e0637fbc42 100644 --- a/paddlenlp/transformers/electra/tokenizer.py +++ b/paddlenlp/transformers/electra/tokenizer.py @@ -93,7 +93,7 @@ class ElectraTokenizer(PretrainedTokenizer): }, "chinese-electra-small": { "do_lower_case": True - } + }, } def __init__(self, diff --git a/tests/transformers/electra/__init__.py b/tests/transformers/electra/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/transformers/electra/test_modeling.py b/tests/transformers/electra/test_modeling.py new file mode 100644 index 000000000000..01559629cfae --- /dev/null +++ b/tests/transformers/electra/test_modeling.py @@ -0,0 +1,168 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import numpy as np +import paddle +from paddlenlp.transformers import ElectraForMaskedLM, \ + ElectraForMultipleChoice, ElectraForQuestionAnswering , ElectraModel,ElectraForSequenceClassification + +from common_test import CommonTest +import unittest + + +def create_input_data(config, seed=None): + ''' + the generated input data will be same if a specified seed is set + ''' + if seed is not None: + np.random.seed(seed) + + input_ids = np.random.randint( + low=0, + high=config['vocab_size'], + size=(config["batch_size"], config["seq_len"])) + + return input_ids + + +class TestElectraForSequenceClassification(CommonTest): + def set_input(self): + self.config = copy.deepcopy(ElectraModel.pretrained_init_configuration[ + 'electra-base']) + self.config['num_hidden_layers'] = 2 + self.config['vocab_size'] = 512 + self.config['attention_probs_dropout_prob'] = 0.0 + self.config['hidden_dropout_prob'] = 0.0 + self.config['intermediate_size'] = 1024 + self.config['seq_len'] = 64 + self.config['batch_size'] = 4 + self.input_ids = create_input_data(self.config) + + def set_output(self): + self.expected_shape = (self.config['batch_size'], 2) + + def setUp(self): + self.set_model_class() + self.set_input() + self.set_output() + + def set_model_class(self): + self.TEST_MODEL_CLASS = ElectraForSequenceClassification + + def check_testcase(self): + self.check_output_equal(self.output.numpy().shape, self.expected_shape) + + def test_forward(self): + config = copy.deepcopy(self.config) + del config['batch_size'] + del config['seq_len'] + + bert = ElectraModel(**config) + model = self.TEST_MODEL_CLASS(bert) + input_ids = paddle.to_tensor(self.input_ids, dtype="int64") + self.output = model(input_ids) + self.check_testcase() + + +class TestElectraForMaskedLM(TestElectraForSequenceClassification): + def set_model_class(self): + self.TEST_MODEL_CLASS = ElectraForMaskedLM + + def set_output(self): + self.expected_seq_shape = (self.config['batch_size'], + self.config['seq_len'], + self.config['vocab_size']) + + def test_forward(self): + config = copy.deepcopy(self.config) + del config['batch_size'] + del config['seq_len'] + + electra = ElectraModel(**config) + model = self.TEST_MODEL_CLASS(electra) + input_ids = paddle.to_tensor(self.input_ids, dtype="int64") + self.output = model(input_ids) + self.check_testcase() + + def check_testcase(self): + self.check_output_equal(self.output.numpy().shape, + self.expected_seq_shape) + + +class TestElectraForQuestionAnswering(TestElectraForSequenceClassification): + def set_model_class(self): + self.TEST_MODEL_CLASS = ElectraForQuestionAnswering + + def set_output(self): + self.expected_start_logit_shape = (self.config['batch_size'], + self.config['seq_len']) + self.expected_end_logit_shape = (self.config['batch_size'], + self.config['seq_len']) + + def check_testcase(self): + self.check_output_equal(self.output[0].numpy().shape, + self.expected_start_logit_shape) + self.check_output_equal(self.output[1].numpy().shape, + self.expected_end_logit_shape) + + +class TestElectraForMultipleChoice(TestElectraForSequenceClassification): + def set_input(self): + self.config = copy.deepcopy(ElectraModel.pretrained_init_configuration[ + 'electra-base']) + self.config['num_hidden_layers'] = 2 + self.config['vocab_size'] = 512 + self.config['attention_probs_dropout_prob'] = 0.0 + self.config['hidden_dropout_prob'] = 0.0 + self.config['intermediate_size'] = 1024 + self.config['seq_len'] = 64 + self.config['batch_size'] = 4 + self.config['num_choices'] = 2 + 
self.config['max_position_embeddings'] = 512 + + self.input_ids = create_input_data(self.config) + # [bs*num_choice,seq_l] -> [bs,num_choice,seq_l] + self.input_ids = np.reshape(self.input_ids, [ + self.config['batch_size'] // self.config['num_choices'], + self.config['num_choices'], -1 + ]) + + def set_model_class(self): + self.TEST_MODEL_CLASS = ElectraForMultipleChoice + + def set_output(self): + self.expected_logit_shape = (self.config['batch_size'] // + self.config['num_choices'], + self.config['num_choices']) + + def check_testcase(self): + self.check_output_equal(self.output.numpy().shape, + self.expected_logit_shape) + + def test_forward(self): + config = copy.deepcopy(self.config) + del config["num_choices"] + del config['batch_size'] + del config['seq_len'] + + electra = ElectraModel(**config) + model = self.TEST_MODEL_CLASS(electra) + input_ids = paddle.to_tensor(self.input_ids, dtype="int64") + self.output = model(input_ids) + self.check_testcase() + + +if __name__ == "__main__": + unittest.main()