Error in fused softmax kernel result #132
Here is a GPT-2 test. It has a similar problem.

```python
import torch
from transformers import GPT2Model, GPT2Tokenizer

# FusedScaleMaskSoftmax, AttnMaskType and attention_mask_func are this repo's
# Megatron fused-softmax modules (see fused_softmax.py); import them from
# wherever they live in your checkout.


def test_upper_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokens = tokenizer(
        [
            "Hello. How are you? I am fine thank you and you? yes Good. hi hello hello hello hello hello hello"
        ]
        * 4,
        return_tensors="pt",
    )

    # Padding mask: 0.0 where tokens are kept, -10000.0 where they are padding,
    # broadcast to [bsz, 1, seq_len, seq_len].
    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)

    # Reuse the first GPT-2 attention block's projections to build real scores.
    embedding = gpt.wte
    c_attn = gpt.h[0].attn.c_attn
    c_bias = gpt.h[0].attn.bias
    _split_heads = gpt.h[0].attn._split_heads
    num_heads = gpt.h[0].attn.num_heads
    head_dim = gpt.h[0].attn.head_dim

    hidden_states = embedding(tokens["input_ids"].cuda())
    q, k, v = c_attn(hidden_states).split(768, dim=-1)
    q = _split_heads(q, num_heads, head_dim)
    k = _split_heads(k, num_heads, head_dim)
    v = _split_heads(v, num_heads, head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    # Combine the causal mask and the padding mask; True marks positions to mask out.
    q_length, k_length = q.size(-2), k.size(-2)
    causal_mask = c_bias[:, :, k_length - q_length : k_length, :k_length].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False,  True,  True,  ...,  True,  True,  True],
              [False, False,  True,  ...,  True,  True,  True],
              [False, False, False,  ...,  True,  True,  True],
              ...,
              [False, False, False,  ..., False,  True,  True],
              [False, False, False,  ..., False, False,  True],
              [False, False, False,  ..., False, False, False]]]
    """

    # Same module twice: once with the fused CUDA kernel, once with the torch fallback.
    fused_softmax = FusedScaleMaskSoftmax(
        mask_func=attention_mask_func,
        attn_mask_type=AttnMaskType.causal,
        input_in_fp16=True,
        input_in_bf16=False,
        scale=None,
        softmax_in_fp32=False,
        scaled_masked_softmax_fusion=True,
    )
    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = FusedScaleMaskSoftmax(
        mask_func=attention_mask_func,
        attn_mask_type=AttnMaskType.causal,
        input_in_fp16=True,
        input_in_bf16=False,
        scale=None,
        softmax_in_fp32=False,
        scaled_masked_softmax_fusion=False,
    )
    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (
        (fused_softmax_output[0][0][-1] - torch_softmax_output[0][0][-1]).abs().max()
    )

    if test_result <= 1e-6:
        print("[Success] test_upper_mask_softmax")
    else:
        print("[Fail] test_upper_mask_softmax")
```
The layer norm kernel works well.

```python
from torch.nn import LayerNorm  # assumed torch baseline for the comparison
from transformers import BertModel, BertTokenizer

# FusedLayerNorm is this repo's fused layer-norm module; import it from
# wherever it lives in your checkout.


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    tokens = tokenizer(
        [
            "Hello. How are you? I am fine thank you and you? yes Good. hi hello hello hello hello"
        ]
        * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        FusedLayerNorm(
            normalized_shape=768,
            eps=1e-5,
        )
        .cuda()
        .half()
    )
    torch_layernorm_layer = (
        LayerNorm(
            normalized_shape=768,
            eps=1e-5,
        )
        .cuda()
        .half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)

    # Reduce the elementwise absolute difference down to a single mean value.
    test_result = (fused_output - torch_output).abs()
    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)
    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
```
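Side note: the mean-reduction loop above checks the average absolute difference; a stricter elementwise variant of the same comparison can be written with `torch.allclose` (a small sketch using the same 1e-3 tolerance):

```python
import torch


def layernorm_outputs_match(fused_output: torch.Tensor, torch_output: torch.Tensor) -> bool:
    # True only if every element of the two fp16 outputs agrees within 1e-3.
    return torch.allclose(fused_output, torch_output, atol=1e-3, rtol=0.0)
```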
I think I'm missing something here... What is this?
They are the same as Megatron's modules (no code changes): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_softmax.py
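For anyone else wondering what the module does: `FusedScaleMaskSoftmax` decides per call whether the custom CUDA kernel can handle the input and otherwise falls back to plain PyTorch ops. The fallback path is roughly the following (a paraphrased sketch, not the exact Megatron source):

```python
import torch


def scale_mask_softmax_torch(scores, mask, mask_func, scale=None):
    # Torch fallback: optionally scale the attention scores, apply the boolean
    # mask via mask_func, then run a regular softmax over the key dimension.
    if scale is not None:
        scores = scores * scale
    scores = mask_func(scores, mask)
    return torch.nn.Softmax(dim=-1)(scores)
```

The fused path replaces all of this with a single CUDA kernel, which is why the two outputs are expected to match.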
|
I think there is something wrong with either the kernel or my test code. I experimented on compute capability 7.0 (V100) and 7.5 (T4), but I don't think it's a problem with the development environment, because the results were the same. I also changed the batch size to 64, but the results were still the same.
Thanks for reporting this. It looks like it is triggered when the sequence length is < 128, which is not something we had tested. Will you test with the following changes?
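The suggested changes themselves aren't visible in this thread. Purely to illustrate the sequence-length observation (a hedged sketch, not the actual fix), a sweep like the following shows at which key lengths the fused and torch paths start to disagree; `FusedScaleMaskSoftmax`, `AttnMaskType`, and `attention_mask_func` are the same repo modules used in the test above:

```python
import torch

# FusedScaleMaskSoftmax, AttnMaskType and attention_mask_func: same repo
# modules as in test_upper_mask_softmax above.


def max_abs_diff_at_seq_len(seq_len: int, batch: int = 4, heads: int = 12) -> float:
    # Random fp16 attention scores plus a lower-triangular causal mask,
    # run through the same module with and without kernel fusion.
    scores = torch.randn(
        batch, heads, seq_len, seq_len, device="cuda", dtype=torch.float16
    )
    causal = torch.tril(torch.ones(seq_len, seq_len, device="cuda")).bool()
    mask = ~causal[None, None, :, :].repeat(batch, 1, 1, 1)

    def build(fused: bool):
        return FusedScaleMaskSoftmax(
            mask_func=attention_mask_func,
            attn_mask_type=AttnMaskType.causal,
            input_in_fp16=True,
            input_in_bf16=False,
            scale=None,
            softmax_in_fp32=False,
            scaled_masked_softmax_fusion=fused,
        )

    fused_out = build(True)(scores, mask)
    torch_out = build(False)(scores, mask)
    return (fused_out - torch_out).abs().max().item()


for s in (16, 32, 64, 96, 128, 256):
    print(f"seq_len={s}: max_abs_diff={max_abs_diff_at_seq_len(s)}")
```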
I have one more suggestion. The constraint that currently exists in the Python code does not match all of the conditions for kernel execution. May I open a PR with some improvements, including the current fixes and test code? (I have already fixed some of the code.)
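For context on "the constraint in the Python code": the module gates the fused path with a shape/dtype check, and the bug class here is that check passing for inputs the CUDA kernel does not actually support. A paraphrased sketch of such a guard is below; the exact conditions live in the repo's fused_softmax.py and kernel sources, and the numbers shown are illustrative assumptions, not the real thresholds:

```python
import torch


def fused_kernel_supported(scores: torch.Tensor, fusion_enabled: bool) -> bool:
    # Illustrative guard only: take the fused path when fusion is enabled, the
    # input is half precision, and the shapes fall inside the ranges the CUDA
    # kernel was written for. If this Python-side check is looser than the
    # kernel's real limits, the fused and torch outputs diverge exactly the way
    # this issue describes.
    b, np, sq, sk = scores.size()
    return (
        fusion_enabled
        and scores.dtype in (torch.float16, torch.bfloat16)
        and 16 < sk <= 2048      # illustrative bound, not the repo's real one
        and sq % 4 == 0          # illustrative alignment requirement
        and (b * np) % 4 == 0    # illustrative alignment requirement
    )
```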
@hyunwoongko That'd be great!
I uploaded the PR. Please review it!
PR is in ... closing
I see the PR hasn't been merged. Closed too soon? I see this commit, which seems to be part of @hyunwoongko's PR: is that the PR you were referring to? But where is the PR? I can only see that commit.
See #133
Exactly! The PR status is OPEN and not merged. Unless you meant something else by "PR is in"
I was puzzled too. I thought the Megatron team had something they didn't want to merge into the main branch. My PR has many improvements like code refactoring, test cases, etc., but those are not the important parts. The most important part of the PR (resolving the kernel errors) has been reflected in the main branch. If my findings can help someone, that's enough for me :) I think it's better to close my PR if the Megatron team has something they don't want to merge.
The reason I'm asking is that in order to sync the downstream we need a commit SHA, which currently isn't there since the PR hasn't been merged. Normally an issue gets closed only once the corresponding PR is merged. So please don't rush to close your PR, and let's give the Megatron-LM team a chance to clarify what's happening.
I see :)
The PR is currently being tested. We will merge it in soon.
@stas00 The PR has been merged!
Yes, thank you! I will ask the Deepspeed folks to sync and then will sync to our fork from there.
Problem?
The result of the fused softmax layer is different from the result of the original torch softmax layer.
How to reproduce?