Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PaddlePaddle Hackathon 54 提交 #1086

Merged
merged 21 commits into from
Oct 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions community/junnyu/ckiplab-bert-base-chinese-ner/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
关于完整使用方法及其他信息,请参考 https://github.com/ckiplab/ckip-transformers 。

**模型结构**: **`BertForTokenClassification`**,带有token分类头的Bert模型。

**适用下游任务**:**命名实体识别**,该权重已经在下游`NER`任务上进行了微调,因此可直接使用。

# 使用示例
Expand Down
1 change: 1 addition & 0 deletions community/junnyu/ckiplab-bert-base-chinese-pos/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
关于完整使用方法及其他信息,请参考 https://github.com/ckiplab/ckip-transformers 。

**模型结构**: **`BertForTokenClassification`**,带有token分类头的Bert模型。

**适用下游任务**:**词性标注**,该权重已经在下游`POS`任务上进行了微调,因此可直接使用。

# 使用示例
Expand Down
1 change: 1 addition & 0 deletions community/junnyu/ckiplab-bert-base-chinese-ws/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
关于完整使用方法及其他信息,请参考 https://github.com/ckiplab/ckip-transformers 。

**模型结构**: **`BertForTokenClassification`**,带有token分类头的Bert模型。

**适用下游任务**:**分词**,该权重已经在下游`WS`任务上进行了微调,因此可直接使用。

# 使用示例
Expand Down
138 changes: 138 additions & 0 deletions community/junnyu/electra_compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import torch
import numpy as np
import paddlenlp.transformers as ppnlp
import transformers as hgnlp


def compare(a, b):
a = a.cpu().numpy()
b = b.cpu().numpy()
meandif = np.abs(a - b).mean()
maxdif = np.abs(a - b).max()
print("mean dif:", meandif)
print("max dif:", maxdif)


def compare_discriminator(
path="junnyu/hfl-chinese-electra-180g-base-discriminator"):
pdmodel = ppnlp.ElectraDiscriminator.from_pretrained(path)
ptmodel = ppnlp.ElectraForPreTraining.from_pretrained(path).cuda()
tokenizer = ppnlp.ElectraTokenizer.from_pretrained(path)
pdmodel.eval()
ptmodel.eval()
text = "欢迎使用paddlenlp!"
pdinputs = {
k: paddle.to_tensor(
v, dtype="int64").unsqueeze(0)
for k, v in tokenizer(text).items()
}
ptinputs = {
k: torch.tensor(
v, dtype=torch.long).unsqueeze(0).cuda()
for k, v in tokenizer(text).items()
}
with paddle.no_grad():
pd_logits = pdmodel(**pdinputs)

with torch.no_grad():
pt_logits = ptmodel(**ptinputs).logits

compare(pd_logits, pt_logits)


def compare_generator():
text = "本院经审查认为,本案[MASK]民间借贷纠纷申请再审案件,应重点审查二审判决是否存在错误的情形。"
# ppnlp
path = "junnyu/hfl-chinese-legal-electra-small-generator"
model = ppnlp.ElectraForMaskedLM.from_pretrained(path)
tokenizer = ppnlp.ElectraTokenizer.from_pretrained(path)
model.eval()
tokens = ["[CLS]"]
text_list = text.split("[MASK]")
for i, t in enumerate(text_list):
tokens.extend(tokenizer.tokenize(t))
if i == len(text_list) - 1:
tokens.extend(["[SEP]"])
else:
tokens.extend(["[MASK]"])

input_ids_list = tokenizer.convert_tokens_to_ids(tokens)
input_ids = paddle.to_tensor([input_ids_list])
with paddle.no_grad():
pd_outputs = model(input_ids)[0]
pd_outputs_sentence = "paddle: "
for i, id in enumerate(input_ids_list):
if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]:
scores, index = paddle.nn.functional.softmax(pd_outputs[i],
-1).topk(5)
tokens = tokenizer.convert_ids_to_tokens(index.tolist())
outputs = []
for score, tk in zip(scores.tolist(), tokens):
outputs.append(f"{tk}={score}")
pd_outputs_sentence += "[" + "||".join(outputs) + "]" + " "
else:
pd_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens(
[id], skip_special_tokens=True)) + " "

print(pd_outputs_sentence)

# transformers
path = "hfl/chinese-legal-electra-small-generator"
config = hgnlp.ElectraConfig.from_pretrained(path)
config.hidden_size = 64
config.intermediate_size = 256
config.num_attention_heads = 1
model = hgnlp.ElectraForMaskedLM.from_pretrained(path, config=config)
tokenizer = hgnlp.ElectraTokenizer.from_pretrained(path)
model.eval()

inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
pt_outputs = model(**inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(inputs["input_ids"][0].tolist()):
if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]:
scores, index = torch.nn.functional.softmax(pt_outputs[i],
-1).topk(5)
tokens = tokenizer.convert_ids_to_tokens(index.tolist())
outputs = []
for score, tk in zip(scores.tolist(), tokens):
outputs.append(f"{tk}={score}")
pt_outputs_sentence += "[" + "||".join(outputs) + "]" + " "
else:
pt_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens(
[id], skip_special_tokens=True)) + " "

print(pt_outputs_sentence)


if __name__ == "__main__":
compare_discriminator(
path="junnyu/hfl-chinese-electra-180g-base-discriminator")
# # mean dif: 3.1698835e-06
# # max dif: 1.335144e-05
compare_discriminator(
path="junnyu/hfl-chinese-electra-180g-small-ex-discriminator")
# mean dif: 3.7930229e-06
# max dif: 1.04904175e-05
compare_generator()
# paddle: 本 院 经 审 查 认 为 , 本 案 [因=0.27444931864738464||经=0.18613006174564362||系=0.09408623725175858||的=0.07536833733320236||就=0.033634234219789505] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。
# pytorch: 本 院 经 审 查 认 为 , 本 案 [因=0.2744344472885132||经=0.1861187219619751||系=0.09407979995012283||的=0.07537488639354706||就=0.03363779932260513] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。
79 changes: 79 additions & 0 deletions community/junnyu/electra_convert_huggingface2paddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
import argparse

huggingface_to_paddle = {
"embeddings.LayerNorm": "embeddings.layer_norm",
"encoder.layer": "encoder.layers",
"attention.self.query.": "self_attn.q_proj.",
"attention.self.key.": "self_attn.k_proj.",
"attention.self.value.": "self_attn.v_proj.",
"attention.output.dense.": "self_attn.out_proj.",
"intermediate.dense": "linear1",
"output.dense": "linear2",
"attention.output.LayerNorm": "norm1",
"output.LayerNorm": "norm2",
"generator_predictions.LayerNorm": "generator_predictions.layer_norm",
"generator_lm_head.bias": "generator_lm_head_bias",
}

skip_weights = ["electra.embeddings.position_ids"]
dont_transpose = ["_embeddings.weight", "LayerNorm."]


def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path,
paddle_dump_path):
import torch
import paddle
pytorch_state_dict = torch.load(pytorch_checkpoint_path, map_location="cpu")
paddle_state_dict = OrderedDict()
for k, v in pytorch_state_dict.items():
if k == "generator_lm_head.weight": continue
is_transpose = False
if k in skip_weights:
continue
if k[-7:] == ".weight":
if not any([w in k for w in dont_transpose]):
if v.ndim == 2:
v = v.transpose(0, 1)
is_transpose = True
oldk = k
for huggingface_name, paddle_name in huggingface_to_paddle.items():
k = k.replace(huggingface_name, paddle_name)

print(f"Converting: {oldk} => {k} | is_transpose {is_transpose}")
paddle_state_dict[k] = v.data.numpy()

paddle.save(paddle_state_dict, paddle_dump_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--pytorch_checkpoint_path",
default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\pytorch_model.bin",
type=str,
required=False,
help="Path to the Pytorch checkpoint path.")
parser.add_argument(
"--paddle_dump_path",
default=r"MODEL\hfl-chinese-electra-180g-base-discriminator\model_state.pdparams",
type=str,
required=False,
help="Path to the output Paddle model.")
args = parser.parse_args()
convert_pytorch_checkpoint_to_paddle(args.pytorch_checkpoint_path,
args.paddle_dump_path)
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# 详细介绍
**介绍**:该模型是base版本的Electra discriminator模型,并且在180G的中文数据上进行训练。

**模型结构**: **`ElectraDiscriminator`**,带有判别器的中文Electra模型。

**适用下游任务**:**通用下游任务**,如:句子级别分类,token级别分类,抽取式问答等任务。

# 使用示例

```python
import paddle
from paddlenlp.transformers import ElectraDiscriminator, ElectraTokenizer

path = "junnyu/hfl-chinese-electra-180g-base-discriminator"
model = ElectraDiscriminator.from_pretrained(path)
tokenizer = ElectraTokenizer.from_pretrained(path)
model.eval()

text = "欢迎使用paddlenlp!"
inputs = {
k: paddle.to_tensor(
v, dtype="int64").unsqueeze(0)
for k, v in tokenizer(text).items()
}

with paddle.no_grad():
logits = model(**inputs)

print(logits.shape)

```

# 权重来源

https://huggingface.co/hfl/chinese-electra-180g-base-discriminator
谷歌和斯坦福大学发布了一种名为 ELECTRA 的新预训练模型,与 BERT 及其变体相比,该模型具有非常紧凑的模型尺寸和相对具有竞争力的性能。 为进一步加快中文预训练模型的研究,HIT与科大讯飞联合实验室(HFL)发布了基于ELECTRA官方代码的中文ELECTRA模型。 与 BERT 及其变体相比,ELECTRA-small 只需 1/10 的参数就可以在几个 NLP 任务上达到相似甚至更高的分数。
这个项目依赖于官方ELECTRA代码: https://github.com/google-research/electra
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-base-discriminator/vocab.txt"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# 详细介绍
**介绍**:该模型是small版本的Electra discriminator模型,并且在180G的中文数据上进行训练。

**模型结构**: **`ElectraDiscriminator`**,带有判别器的中文Electra模型。

**适用下游任务**:**通用下游任务**,如:句子级别分类,token级别分类,抽取式问答等任务。

# 使用示例

```python
import paddle
from paddlenlp.transformers import ElectraDiscriminator,ElectraTokenizer

path = "junnyu/hfl-chinese-electra-180g-small-ex-discriminator"
model = ElectraDiscriminator.from_pretrained(path)
tokenizer = ElectraTokenizer.from_pretrained(path)
model.eval()

text = "欢迎使用paddlenlp!"
inputs = {
k: paddle.to_tensor(
v, dtype="int64").unsqueeze(0)
for k, v in tokenizer(text).items()
}

with paddle.no_grad():
logits = model(**inputs)

print(logits.shape)

```

# 权重来源

https://huggingface.co/hfl/chinese-electra-180g-small-ex-discriminator
谷歌和斯坦福大学发布了一种名为 ELECTRA 的新预训练模型,与 BERT 及其变体相比,该模型具有非常紧凑的模型尺寸和相对具有竞争力的性能。 为进一步加快中文预训练模型的研究,HIT与科大讯飞联合实验室(HFL)发布了基于ELECTRA官方代码的中文ELECTRA模型。 与 BERT 及其变体相比,ELECTRA-small 只需 1/10 的参数就可以在几个 NLP 任务上达到相似甚至更高的分数。
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-electra-180g-small-ex-discriminator/vocab.txt"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# 详细介绍
**介绍**:该模型是small版本的Electra generator模型,该模型在法律领域数据上进行了预训练。

**模型结构**: **`ElectraGenerator`**,带有生成器的中文Electra模型。

**适用下游任务**:**法律领域的下游任务**,如:法律领域的句子级别分类,法律领域的token级别分类,法律领域的抽取式问答等任务。
(注:生成器的效果不好,通常使用判别器进行下游任务微调)


# 使用示例

```python
import paddle
from paddlenlp.transformers import ElectraGenerator, ElectraTokenizer

text = "本院经审查认为,本案[MASK]民间借贷纠纷申请再审案件,应重点审查二审判决是否存在错误的情形。"
path = "junnyu/hfl-chinese-legal-electra-small-generator"
model = ElectraGenerator.from_pretrained(path)
model.eval()
tokenizer = ElectraTokenizer.from_pretrained(path)

tokens = ["[CLS]"]
text_list = text.split("[MASK]")
for i, t in enumerate(text_list):
tokens.extend(tokenizer.tokenize(t))
if i == len(text_list) - 1:
tokens.extend(["[SEP]"])
else:
tokens.extend(["[MASK]"])

input_ids_list = tokenizer.convert_tokens_to_ids(tokens)
input_ids = paddle.to_tensor([input_ids_list])
with paddle.no_grad():
pd_outputs = model(input_ids)[0]
pd_outputs_sentence = "paddle: "
for i, id in enumerate(input_ids_list):
if id == tokenizer.convert_tokens_to_ids(["[MASK]"])[0]:
scores, index = paddle.nn.functional.softmax(pd_outputs[i],
-1).topk(5)
tokens = tokenizer.convert_ids_to_tokens(index.tolist())
outputs = []
for score, tk in zip(scores.tolist(), tokens):
outputs.append(f"{tk}={score}")
pd_outputs_sentence += "[" + "||".join(outputs) + "]" + " "
else:
pd_outputs_sentence += "".join(
tokenizer.convert_ids_to_tokens(
[id], skip_special_tokens=True)) + " "

print(pd_outputs_sentence)
# paddle: 本 院 经 审 查 认 为 , 本 案 [因=0.27444931864738464||经=0.18613006174564362||系=0.09408623725175858||的=0.07536833733320236||就=0.033634234219789505] 民 间 借 贷 纠 纷 申 请 再 审 案 件 , 应 重 点 审 查 二 审 判 决 是 否 存 在 错 误 的 情 形 。
```

# 权重来源

https://huggingface.co/hfl/chinese-legal-electra-small-generator
谷歌和斯坦福大学发布了一种名为 ELECTRA 的新预训练模型,与 BERT 及其变体相比,该模型具有非常紧凑的模型尺寸和相对具有竞争力的性能。 为进一步加快中文预训练模型的研究,HIT与科大讯飞联合实验室(HFL)发布了基于ELECTRA官方代码的中文ELECTRA模型。 与 BERT 及其变体相比,ELECTRA-small 只需 1/10 的参数就可以在几个 NLP 任务上达到相似甚至更高的分数。
这个项目依赖于官方ELECTRA代码: https://github.com/google-research/electra
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"model_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_config.json",
"model_state": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/model_state.pdparams",
"tokenizer_config_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/tokenizer_config.json",
"vocab_file": "https://paddlenlp.bj.bcebos.com/models/transformers/community/junnyu/hfl-chinese-legal-electra-small-generator/vocab.txt"
}
Loading