
Exception with long-sequence inputs: does roformer support variable-length inputs? #36

Closed
likestudy opened this issue Jun 30, 2022 · 2 comments

@likestudy

Hello, thank you very much for open-sourcing this project and providing the pip package.

When using roformer, short inputs (< 512 tokens) work fine, but longer inputs raise an error and stop execution. Does roformer_pytorch support variable-length inputs?

The main error messages are:

/opt/conda/conda-bld/pytorch_1646755861072/work/aten/src/ATen/native/cuda/Indexing.cu:703: indexSelectLargeIndex: block: [335,0,0], thread: [93,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/opt/conda/conda-bld/pytorch_1646755861072/work/aten/src/ATen/native/cuda/Indexing.cu:703: indexSelectLargeIndex: block: [335,0,0], thread: [94,0,0] Assertion `srcIndex < srcSelectDimSize` failed.

  File "XXX/roformer/modeling_roformer.py", line 1075, in forward
    attention_mask, input_shape, device, past_key_values_length
  File "XXX/roformer/modeling_roformer.py", line 1158, in get_extended_attention_mask
    extended_attention_mask = extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
RuntimeError: CUDA error: device-side assert triggered
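For context, a minimal sketch for checking whether a tokenized input exceeds the model's configured position-embedding size; the repeated long_text here is only a hypothetical illustrative input, and the code simply reports the checkpoint's configured maximum rather than assuming a particular value:

import torch
from transformers import BertTokenizer
from roformer import RoFormerForMaskedLM

tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_chinese_char_base")
model = RoFormerForMaskedLM.from_pretrained("junnyu/roformer_chinese_char_base")

long_text = "你好" * 600  # hypothetical long input, well beyond 512 characters
input_ids = tokenizer(long_text, return_tensors="pt")["input_ids"]

# If the tokenized length exceeds the position-embedding table, the CUDA
# index_select assertion shown above is what surfaces instead of a clear error.
print("input length:", input_ids.shape[1])
print("max_position_embeddings:", model.config.max_position_embeddings)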
@JunnYu
Owner

JunnYu commented Jun 30, 2022

import torch
from transformers import BertTokenizer
from roformer import RoFormerForMaskedLM

text = "今天[MASK]很好,我[MASK]去公园玩。"
tokenizer = BertTokenizer.from_pretrained("junnyu/roformer_chinese_char_base")
# Pass max_position_embeddings when loading so the model can handle inputs
# longer than the default maximum.
pt_model = RoFormerForMaskedLM.from_pretrained(
    "junnyu/roformer_chinese_char_base", max_position_embeddings=1024
)

# Pad the input all the way to 1024 tokens to exercise the longer limit.
pt_inputs = tokenizer(text, return_tensors="pt", padding="max_length", max_length=1024)
# pytorch
with torch.no_grad():
    pt_outputs = pt_model(**pt_inputs).logits[0]
pt_outputs_sentence = "pytorch: "
for i, id in enumerate(tokenizer.encode(text)):
    if id == tokenizer.mask_token_id:
        # Show the top-5 predictions for each [MASK] position.
        tokens = tokenizer.convert_ids_to_tokens(pt_outputs[i].topk(k=5)[1])
        pt_outputs_sentence += "[" + "||".join(tokens) + "]"
    else:
        pt_outputs_sentence += "".join(
            tokenizer.convert_ids_to_tokens([id], skip_special_tokens=True)
        )
print(pt_outputs_sentence)
  • The key point is to specify the max_position_embeddings parameter when calling from_pretrained.
  • For example, max_position_embeddings=1024 means the model accepts inputs up to 1024 tokens long (a variable-length usage sketch follows below).
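As a follow-up on the variable-length question, a minimal sketch assuming the same tokenizer and pt_model loaded above with max_position_embeddings=1024: tokenizing without padding="max_length" lets each input keep its own length, and truncation caps it at the configured maximum.

# Variable-length usage sketch: reuses tokenizer / pt_model from the snippet above.
texts = ["今天[MASK]很好。", "我[MASK]去公园玩。" * 50]  # hypothetical short and long inputs
for t in texts:
    inputs = tokenizer(t, return_tensors="pt", truncation=True, max_length=1024)
    with torch.no_grad():
        logits = pt_model(**inputs).logits
    # Each input keeps its own sequence length; no fixed-size padding needed.
    print(inputs["input_ids"].shape, logits.shape)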

@likestudy
Author

Thank you very much for your reply; the issue has been resolved.
