Error in fused softmax kernel result #132

Closed
hyunwoongko opened this issue Aug 12, 2021 · 22 comments

hyunwoongko (Contributor) commented Aug 12, 2021

Problem

[Screenshot, 2021-08-12 11:28 AM: fused vs. torch softmax outputs]

The result of the fused softmax layer is different from the result of the original torch softmax layer.

How to reproduce

import math

import torch
from torch.nn import Softmax
from transformers import BertTokenizer
from transformers.models.bert.modeling_bert import BertModel
from fused import FusedScaleMaskSoftmax
from fused import AttnMaskType

def load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


def test_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

    # len_query=24, batch_per_block=8 (in my setting)
    tokens = tokenizer(
        [
            "Hello. How are you? I am fine thank you and you? yes Good. hi hello hello hello hello"
        ]
        * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len), all values are 0.
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    query_proj = attention.query
    key_proj = attention.key
    value_proj = attention.value

    key_layer = attention.transpose_for_scores(key_proj(embedding_output))
    value_layer = attention.transpose_for_scores(value_proj(embedding_output))
    query_layer = attention.transpose_for_scores(query_proj(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = FusedScaleMaskSoftmax(
        mask_func=attention_mask_func,
        attn_mask_type=AttnMaskType.padding,
        input_in_fp16=True,
        input_in_bf16=False,
        scale=None,
        softmax_in_fp32=False,
        scaled_masked_softmax_fusion=True,
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = FusedScaleMaskSoftmax(
        mask_func=attention_mask_func,
        attn_mask_type=AttnMaskType.padding,
        input_in_fp16=True,
        input_in_bf16=False,
        scale=None,
        softmax_in_fp32=False,
        scaled_masked_softmax_fusion=False,
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    print("fused (turn on fusion):", fused_softmax_output[0][0][0])
    print("\n")
    print("fused (turn off fusion):", torch_softmax_output[0][0][0])

    torch_softmax = torch.nn.Softmax(dim=-1)
    torch_softmax_output = torch_softmax(attention_scores)

    print("\n")
    print("torch softmax", torch_softmax_output[0][0][0])


if __name__ == "__main__":
    load_fused_kernels()
    test_softmax()
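
Not part of the original script: to quantify the mismatch, one could compare the two outputs right after both FusedScaleMaskSoftmax calls (i.e., before torch_softmax_output is overwritten by the plain torch.nn.Softmax result). The 1e-3 tolerance below is an assumption for fp16, not a project-defined threshold.

    # hypothetical check appended inside test_softmax(), after the second FusedScaleMaskSoftmax call
    max_diff = (fused_softmax_output - torch_softmax_output).abs().max().item()
    print(f"max |fused - unfused| = {max_diff}")
    assert max_diff <= 1e-3, "fused softmax diverges from the unfused path"
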
hyunwoongko (Contributor Author) commented Aug 12, 2021

Here is a GPT-2 test. It has a similar problem. (It reuses attention_mask_func, FusedScaleMaskSoftmax, and AttnMaskType from the script above.)

import torch
from transformers import GPT2Model, GPT2Tokenizer


def test_upper_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokens = tokenizer(
        [
            "Hello. How are you? I am fine thank you and you? yes Good. hi hello hello hello hello hello hello"
        ]
        * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)

    embedding = gpt.wte
    c_attn = gpt.h[0].attn.c_attn
    c_bias = gpt.h[0].attn.bias
    _split_heads = gpt.h[0].attn._split_heads
    num_heads = gpt.h[0].attn.num_heads
    head_dim = gpt.h[0].attn.head_dim

    hidden_states = embedding(tokens["input_ids"].cuda())
    q, k, v = c_attn(hidden_states).split(768, dim=-1)
    q = _split_heads(q, num_heads, head_dim)
    k = _split_heads(k, num_heads, head_dim)
    v = _split_heads(v, num_heads, head_dim)

    attn_weights = torch.matmul(q, k.transpose(-1, -2))
    q_length, k_length = q.size(-2), k.size(-2)
    causal_mask = c_bias[:, :, k_length - q_length : k_length, :k_length].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False,  True,  True,  ...,  True,  True,  True],
              [False, False,  True,  ...,  True,  True,  True],
              [False, False, False,  ...,  True,  True,  True],
              ...,
              [False, False, False,  ..., False,  True,  True],
              [False, False, False,  ..., False, False,  True],
              [False, False, False,  ..., False, False, False]]]
    """

    fused_softmax = FusedScaleMaskSoftmax(
        mask_func=attention_mask_func,
        attn_mask_type=AttnMaskType.causal,
        input_in_fp16=True,
        input_in_bf16=False,
        scale=None,
        softmax_in_fp32=False,
        scaled_masked_softmax_fusion=True,
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = FusedScaleMaskSoftmax(
        mask_func=attention_mask_func,
        attn_mask_type=AttnMaskType.causal,
        input_in_fp16=True,
        input_in_bf16=False,
        scale=None,
        softmax_in_fp32=False,
        scaled_masked_softmax_fusion=False,
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (
        (fused_softmax_output[0][0][-1] - torch_softmax_output[0][0][-1]).abs().max()
    )

    if test_result <= 1e-6:
        print("[Success] test_upper_mask_softmax")
    else:
        print("[Fail] test_upper_mask_softmax")

hyunwoongko (Contributor Author) commented:

The layer norm kernel works well.

def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    tokens = tokenizer(
        [
            "Hello. How are you? I am fine thank you and you? yes Good. hi hello hello hello hello"
        ]
        * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        FusedLayerNorm(
            normalized_shape=768,
            eps=1e-5,
        )
        .cuda()
        .half()
    )

    torch_layernorm_layer = (
        LayerNorm(
            normalized_shape=768,
            eps=1e-5,
        )
        .cuda()
        .half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )

jaredcasper (Collaborator) commented:

I think I'm missing something here... you import scaled_masked_softmax_cuda in load_fused_kernels but then don't use it. You seem to actually be testing the functions imported here:

from fused import FusedScaleMaskSoftmax
from fused import AttnMaskType

What is this fused module?

hyunwoongko (Contributor Author) commented Aug 12, 2021

They are the same as Megatron's modules (no code changes). I just moved the related code into a fused directory:

https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/fused_softmax.py
https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/enums.py
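
(For context, the local fused package is presumably just a copy that re-exports those two files; a plausible fused/__init__.py, which is an assumption and not shown in this issue, would look like:)

# fused/__init__.py -- hypothetical layout; the files themselves are verbatim copies
# of megatron/model/fused_softmax.py and megatron/model/enums.py linked above.
from .fused_softmax import FusedScaleMaskSoftmax  # Python wrapper that dispatches to the CUDA kernels
from .enums import AttnMaskType                   # enum providing the padding / causal mask types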

hyunwoongko (Contributor Author) commented Aug 12, 2021

The import scaled_masked_softmax_cuda in load_fused_kernels is just an import test; it does not affect the softmax test.

hyunwoongko (Contributor Author) commented Aug 12, 2021

I think there is something wrong with either the kernel or my test code. I experimented on compute capability 7.0 (V100) and 7.5 (T4), and the results were identical, so I don't think it's a problem with the development environment. I also changed the batch size to 64, but the results were still the same.

jaredcasper (Collaborator) commented:

Thanks for reporting this. It looks like it is triggered when the sequence length is < 128, which is not something we had tested.

Will you test with the following changes?

diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h
index 78e97e4ec60d545a93bf628e6a0135aca707baf5..e80bfe647872dadf912a572560f84ad201315242 100644
--- a/megatron/fused_kernels/scaled_masked_softmax.h
+++ b/megatron/fused_kernels/scaled_masked_softmax.h
@@ -111,7 +111,7 @@ __global__ void scaled_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches) 
@@ -230,7 +230,7 @@ __global__ void scaled_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
     // gridDim/blockIdx = (seq_len, attn_heads, batches) 
diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
index addca0a0a3bbe5322c4e4522471f17a3c00e2ee1..ca722cbbc626929fe0b3d54c0fcdc74b594dc7b0 100644
--- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
+++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
@@ -125,7 +125,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_forward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1; 
@@ -245,7 +245,7 @@ __global__ void scaled_upper_triang_masked_softmax_warp_backward(
     constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
     constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
     constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
-    constexpr int ELEMENTS_PER_LDG_STG = 4;
+    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;
 
     int first_batch = (blockDim.y * blockIdx.y + threadIdx.y) * gridDim.x * WARP_BATCH + blockIdx.x;
     int local_seq = blockIdx.x + 1; 
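
(Editorial note on why the clamp matters, based on reading the constants above rather than on anything stated in this thread: the kernel sizes its per-thread buffers by WARP_ITERATIONS, while a fixed ELEMENTS_PER_LDG_STG = 4 moves four elements per vectorized load/store, which only fits when WARP_ITERATIONS >= 4. For the reporter's len_query = 24, the row is padded to the next power of two, 32, giving WARP_SIZE = 32 and WARP_ITERATIONS = 1. A quick Python sketch of the constexpr arithmetic:)

# Illustrative only: mirrors the constexpr computations in the kernels above.
C10_WARP_SIZE = 32

def kernel_constants(seq_len):
    next_power_of_two = 1
    while next_power_of_two < seq_len:
        next_power_of_two *= 2
    warp_size = min(next_power_of_two, C10_WARP_SIZE)
    warp_iterations = next_power_of_two // warp_size
    old_ldg_stg = 4                                # before the patch
    new_ldg_stg = 1 if warp_iterations < 4 else 4  # after the patch
    return next_power_of_two, warp_size, warp_iterations, old_ldg_stg, new_ldg_stg

print(kernel_constants(24))    # (32, 32, 1, 4, 1): a 4-wide access does not fit a 1-slot buffer
print(kernel_constants(1024))  # (1024, 32, 32, 4, 4): long sequences are unaffected by the patch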

hyunwoongko (Contributor Author) commented:

@jaredcasper

It works well! Thank you.

hyunwoongko (Contributor Author) commented Aug 12, 2021

I have one more suggestion. The constraints that currently exist in the Python code do not cover all of the conditions required for kernel execution. May I open a PR with some improvements, including the current fixes and test code? (I have already fixed some of the code.)
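
(A purely hypothetical illustration of the kind of Python-side guard being proposed; the limits below are placeholders, not the values in Megatron or in the eventual PR:)

# Hypothetical sketch: keep the Python-side dispatch condition in sync with what
# the CUDA kernels can actually handle. All numeric limits here are placeholders.
def fusion_is_applicable(input_in_float16, fusion_enabled, seq_q, seq_k):
    return (
        fusion_enabled
        and input_in_float16      # the fused kernels operate on fp16/bf16 inputs
        and 0 < seq_k <= 2048     # placeholder bound on the key sequence length
        and seq_q > 1             # placeholder bound on the query sequence length
    )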

jaredcasper (Collaborator) commented:

@hyunwoongko That'd be great!

hyunwoongko (Contributor Author) commented Aug 12, 2021

#133

I uploaded the PR. Please review it!

shoeybi (Collaborator) commented Aug 17, 2021

PR is in ... closing

shoeybi closed this as completed Aug 17, 2021
stas00 (Contributor) commented Aug 20, 2021

I see the PR hasn't been merged. Closed too soon?

I see this commit, which seems to be part of @hyunwoongko's PR:
2387ce0

Is that the PR you were referring to? But where is the PR? I can only see that commit.

shoeybi (Collaborator) commented Aug 21, 2021

See #133

stas00 (Contributor) commented Aug 21, 2021

Exactly! The PR status is OPEN and not merged.

Unless you meant something else by "PR is in"

hyunwoongko (Contributor Author) commented Aug 21, 2021

I was puzzled too. I thought the Megatron team had something they didn't want to merge into the main branch.

My PR has many improvements, like code refactoring and test cases, but those are not the important parts. The most important part of the PR (resolving the kernel errors) has already been reflected in the main branch. If my findings can help someone, that's enough for me :) I think it's better to close my PR if the Megatron team has something they don't want to merge.

stas00 (Contributor) commented Aug 21, 2021

The reason I'm asking is that, in order to sync the downstream fork, we need a commit SHA, which currently isn't there, since the PR hasn't been merged.

Normally an issue gets closed only once the corresponding PR is merged.

Therefore, please don't rush to close your PR; let's give the Megatron-LM team a chance to clarify what's happening.

hyunwoongko (Contributor Author) commented:

I see :)

shoeybi (Collaborator) commented Aug 22, 2021

The PR is currently being tested. We will merge it in soon.

hyunwoongko (Contributor Author) commented:

@stas00 The PR has been merged!

stas00 (Contributor) commented Sep 2, 2021

Yes, thank you! I will ask the DeepSpeed folks to sync, and then I will sync to our fork from there.

conglongli pushed a commit to conglongli/Megatron-DeepSpeed that referenced this issue Jun 15, 2023