
[NEW MODEL] Add_XLM_model #2080

Merged
merged 22 commits into PaddlePaddle:develop
Jun 16, 2022

Conversation

JunnYu
Member

@JunnYu JunnYu commented May 8, 2022

PR types

New features

PR changes

Models

Description

Add XLM model.
[PaddlePaddle Paper Reproduction Challenge (Round 6)] 110 XLM: Cross-lingual Language Model Pretraining

@JunnYu
Member Author

JunnYu commented May 8, 2022

Batch weight-conversion script:

import os

# Parameter names containing these substrings are never transposed
# (embedding tables and layer-norm weights keep their original layout).
donot_transpose = [
    ".layer_norm", ".position_embeddings.", ".lang_embeddings.", ".embeddings."
]

def convert_pytorch_checkpoint_to_paddle(pytorch_checkpoint_path="pytorch_model.bin",
                                         paddle_dump_path="model_state.pdparams"):
    import torch
    import paddle
    from collections import OrderedDict
    pytorch_state_dict = torch.load(
        pytorch_checkpoint_path, map_location="cpu")
    paddle_state_dict = OrderedDict()
    for k, v in pytorch_state_dict.items():
        is_transpose = False
        # torch.nn.Linear stores weights as [out_features, in_features], while
        # paddle.nn.Linear expects [in_features, out_features], so 2-D weights
        # (except those in donot_transpose) are transposed.
        if k.endswith(".weight"):
            if not any(d in k for d in donot_transpose):
                if v.ndim == 2:
                    v = v.transpose(0, 1)
                    is_transpose = True
        oldk = k
        k = k.replace("transformer", "xlm")
        # skip pred_layer.proj.weight; only the bias is kept (renamed below)
        if "pred_layer.proj.weight" in k:
            continue
        if "pred_layer.proj.bias" in k:
            k = k.replace(".proj.", ".")
        print(f"Converting: {oldk} => {k} is_transpose {is_transpose}")
        paddle_state_dict[k] = v.data.numpy().astype("float32")
    paddle.save(paddle_state_dict, paddle_dump_path)
    
mapdict = {
    "xlm-mlm-xnli15-1024": "https://huggingface.co/xlm-mlm-xnli15-1024/resolve/main/merges.txt",
    "xlm-mlm-en-2048": "https://huggingface.co/xlm-mlm-en-2048/resolve/main/merges.txt",
    "xlm-mlm-ende-1024": "https://huggingface.co/xlm-mlm-ende-1024/resolve/main/merges.txt",
    "xlm-mlm-enfr-1024": "https://huggingface.co/xlm-mlm-enfr-1024/resolve/main/merges.txt",
    "xlm-mlm-enro-1024": "https://huggingface.co/xlm-mlm-enro-1024/resolve/main/merges.txt",
    "xlm-mlm-tlm-xnli15-1024": "https://huggingface.co/xlm-mlm-tlm-xnli15-1024/resolve/main/merges.txt",
    "xlm-clm-enfr-1024": "https://huggingface.co/xlm-clm-enfr-1024/resolve/main/merges.txt",
    "xlm-clm-ende-1024": "https://huggingface.co/xlm-clm-ende-1024/resolve/main/merges.txt",
    "xlm-mlm-17-1280": "https://huggingface.co/xlm-mlm-17-1280/resolve/main/merges.txt",
    "xlm-mlm-100-1280": "https://huggingface.co/xlm-mlm-100-1280/resolve/main/merges.txt",
}

for name, url in mapdict.items():
    # mkdir
    os.makedirs(name, exist_ok=True)
    os.chdir(name)

    # convert model bin
    model_bin_url = url.replace("merges.txt", "pytorch_model.bin")
    os.system(f"wget {model_bin_url}")
    convert_pytorch_checkpoint_to_paddle()
    
    # convert vocab and merges
    merges_url = url
    os.system(f"wget {merges_url}")
    vocab_url = url.replace("merges.txt", "vocab.json")
    os.system(f"wget {vocab_url}")
    os.chdir("../")

@yingyibiao yingyibiao requested a review from gongel May 10, 2022 02:53
@gongel
Member

gongel commented May 10, 2022

Thanks for the contribution! Please also add an example 😊

@JunnYu
Member Author

JunnYu commented May 10, 2022

@gongel Added.
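
For reference, a minimal usage sketch with the classes this PR adds and that the alignment script further below exercises (paddlenlp.transformers.XLMTokenizer / XLMModel); the concrete example added to the PR may differ:

import paddle
from paddlenlp.transformers import XLMModel, XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")
model = XLMModel.from_pretrained("xlm-mlm-en-2048")
model.eval()

inputs = tokenizer("This is an example sentence.", return_attention_mask=True)
with paddle.no_grad():
    # model(...)[0] is the last hidden state, as used in the alignment script below.
    sequence_output = model(
        input_ids=paddle.to_tensor([inputs["input_ids"]]),
        attention_mask=paddle.to_tensor([inputs["attention_mask"]]))[0]
print(sequence_output.shape)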

Member

@gongel gongel left a comment

Also, the returned token_type_ids are incorrect; please fix this.

paddlenlp/transformers/xlm/tokenizer.py
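
For context, a minimal hypothetical sketch of the conventional segment-id layout such a tokenizer method returns for a sequence pair (segment 0 for the first sequence plus its special tokens, segment 1 for the second); this only illustrates the convention and is not the code merged in this PR:

def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
    # Hypothetical sketch: <special> A </special> -> 0s, B </special> -> 1s.
    cls = [self.cls_token_id]
    sep = [self.sep_token_id]
    if token_ids_1 is None:
        return len(cls + token_ids_0 + sep) * [0]
    return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
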
@gongel
Member

gongel commented Jun 16, 2022

Facebook's original XLM does not take a token_type_ids argument, while the Hugging Face version does; the outputs align as long as this argument is not passed. The alignment script is as follows:

import os

import numpy as np
import paddle
import torch
import transformers as hfnlp
import paddlenlp.transformers as ppnlp

# Redirect the download caches of the two frameworks.
os.environ["TRANSFORMERS_CACHE"] = "./hf/"
os.environ["PPNLP_HOME"] = "./pdnlp/"


def compute_diff(torch_data, paddle_data):
    # Summarize the element-wise absolute difference between the two outputs.
    torch_data = torch_data.detach().numpy()
    paddle_data = paddle_data.numpy()
    diff = np.abs(torch_data - paddle_data)
    return "max: {}    mean: {}    min: {}".format(diff.max(), diff.mean(), diff.min())


def compare_base(model_id):
    sentences = [
        "This is an example sentence.", 
        "Each sentence is converted .", 
        "欢迎使用 PaddlePaddle  。",
        "欢迎使用 PaddleNLP 。"
    ]

    # Calculate HF output
    hf_tokenizer = hfnlp.XLMTokenizer.from_pretrained(model_id)
    hf_model = hfnlp.XLMModel.from_pretrained(model_id)
    hf_model.eval()
    with torch.no_grad():
        hf_inputs = hf_tokenizer(sentences, padding=True, return_tensors="pt")
        # The original Facebook XLM has no token_type_ids; drop them for alignment.
        hf_inputs.pop('token_type_ids')
        print(hf_inputs)
        hf_out = hf_model(**hf_inputs).last_hidden_state

    # Calculate Paddle output
    pd_tokenizer = ppnlp.XLMTokenizer.from_pretrained(model_id)
    pd_model = ppnlp.XLMModel.from_pretrained(model_id)
    pd_model.eval()
    with paddle.no_grad():
        pd_inputs = pd_tokenizer(sentences, padding=True, return_attention_mask=True)
        print(pd_inputs)
        pd_out = pd_model(
            input_ids=paddle.to_tensor(pd_inputs['input_ids']),
            attention_mask=paddle.to_tensor(pd_inputs['attention_mask']))[0]

    return compute_diff(hf_out, pd_out)


print(compare_base('xlm-mlm-en-2048'))

@gongel gongel merged commit 07bafdc into PaddlePaddle:develop Jun 16, 2022
@gongel
Member

gongel commented Jun 27, 2022

The tokenizer's bug has been fixed by #2549 and #2551.
