diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp
index d5334710cf..1852aee6fd 100644
--- a/megatron/fused_kernels/scaled_masked_softmax.cpp
+++ b/megatron/fused_kernels/scaled_masked_softmax.cpp
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
     torch::Tensor const& softmax_results,
     float scale_factor);

+int get_batch_per_block_cuda(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads);
+
 torch::Tensor fwd(
     torch::Tensor const& input,
     torch::Tensor const& mask,
@@ -63,6 +69,14 @@ torch::Tensor bwd(
   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }

+int get_batch_per_block(
+    int query_seq_len,
+    int key_seq_len,
+    int batches,
+    int attn_heads) {
+  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
 } // end namespace scaled_masked_softmax
 } // end namespace fused_softmax
 } // end namespace multihead_attn
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
-  m.def("backward",
+
+  m.def("backward",
         &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
         "Self Multihead Attention scaled, time masked softmax -- Backward.");
+
+  m.def("get_batch_per_block",
+        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
+        "Return batch per block size."
+  );
 }
diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h
index e80bfe6478..45e8dcea27 100644
--- a/megatron/fused_kernels/scaled_masked_softmax.h
+++ b/megatron/fused_kernels/scaled_masked_softmax.h
@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
             }
         }
     }
-
 } // end of anonymous namespace

+int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+    int log2_elements = log2_ceil(key_seq_len);
+    const int next_power_of_two = 1 << log2_elements;
+
+    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
+    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;
+
+    constexpr int threads_per_block = 128;
+    int warps_per_block = (threads_per_block / warp_size);
+    int batches_per_block = warps_per_block * batches_per_warp;
+
+    return batches_per_block;
+}
+
 template<typename input_t, typename output_t, typename acc_t>
 void dispatch_scaled_masked_softmax_forward(
     output_t *dst,
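One observation worth making explicit for reviewers: although the new get_batch_per_block takes four arguments, only key_seq_len affects the result; query_seq_len, batches, and attn_heads are carried along to mirror the dispatcher signatures. The arithmetic is compact enough to restate in Python. This is a hypothetical sketch for illustration only (not part of the patch), assuming C10_WARP_SIZE == 32:

def batch_per_block(key_seq_len, threads_per_block=128, warp_size_limit=32):
    # log2_ceil: round key_seq_len up to the next power of two
    next_power_of_two = 1 << (key_seq_len - 1).bit_length()
    # narrow the warp for short rows, cap at the hardware warp size
    warp_size = min(next_power_of_two, warp_size_limit)
    # rows padded to <= 128 elements are cheap enough to pack two batches per warp
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    warps_per_block = threads_per_block // warp_size
    return warps_per_block * batches_per_warp

assert batch_per_block(128) == 8   # 4 warps/block * 2 batches/warp
assert batch_per_block(2048) == 4  # 4 warps/block * 1 batch/warp

So the heuristic returns at most 16, and settles at 4 once the padded row length exceeds 128; that is the figure the Python-side divisibility checks later in this diff have to satisfy.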
diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
index 7e8317c4f4..902d36dd0f 100644
--- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
+++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu
@@ -28,6 +28,11 @@ namespace multihead_attn {
 namespace fused_softmax {
 namespace scaled_masked_softmax {

+int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
+  return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
+}
+
+
 torch::Tensor fwd_cuda(
     torch::Tensor const& input,
     torch::Tensor const& mask,
diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
index ca722cbbc6..6df83fc103 100644
--- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
+++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
     TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
     int blocks_per_seq = attn_batches / batches_per_block;
     dim3 blocks(seq_len, blocks_per_seq, 1);
     dim3 threads(warp_size, warps_per_block, 1);
@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
     int warps_per_block = (threads_per_block / warp_size);
     int batches_per_block = warps_per_block * batches_per_warp;
     TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);
+
     int blocks_per_seq = attn_batches / batches_per_block;
     dim3 blocks(seq_len, blocks_per_seq, 1);
     dim3 threads(warp_size, warps_per_block, 1);
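For context on how batches_per_block shapes the launch grid in these dispatchers, here is an invented worked example; none of the shape numbers come from the patch itself, and a 32-lane warp is assumed:

# Causal kernel launch for seq_len=1024 with attn_batches = batches * attn_heads = 8 * 12 = 96.
# key_seq_len=1024 rounds to 1024 > 128, so batches_per_warp = 1, warp_size = 32,
# warps_per_block = 128 // 32 = 4, and batches_per_block = 4 * 1 = 4.
attn_batches = 8 * 12
batches_per_block = 4
assert attn_batches % batches_per_block == 0  # what TORCH_INTERNAL_ASSERT enforces
blocks_per_seq = attn_batches // batches_per_block  # 24
blocks = (1024, blocks_per_seq, 1)  # dim3 blocks(seq_len, blocks_per_seq, 1)
threads = (32, 4, 1)                # dim3 threads(warp_size, warps_per_block, 1)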
" + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + embedding_output = bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + + # (bsz, 1, 1, seq_len) + mask = bert.get_extended_attention_mask( + attention_mask=tokens["attention_mask"].cuda(), + input_shape=tokens["input_ids"].shape, + device=bert.device, + ) + # (bsz, 1, seq_len, seq_len) + mask = mask.repeat(1, 1, mask.size()[-1], 1) + + attention = bert.encoder.layer[0].attention.self + key_layer = attention.transpose_for_scores(attention.key(embedding_output)) + query_layer = attention.transpose_for_scores(attention.query(embedding_output)) + + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores /= math.sqrt(key_layer.size()[-1]) + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attention_scores, + (mask != 0), + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.padding, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attention_scores, + (mask != 0), + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_fused_upper_triangle_mask_softmax(): + gpt = GPT2Model.from_pretrained("gpt2").cuda().half() + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. 
" + "hi hi hi hi hi hi hi" # 24 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + attention_mask = tokens["attention_mask"].cuda() + attention_mask = attention_mask.view(attention_mask.size(0), -1) + attention_mask = attention_mask[:, None, None, :] + attention_mask = (1.0 - attention_mask) * -10000.0 + attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1) + attn = gpt.h[0] + + hidden_states = gpt.wte(tokens["input_ids"].cuda()) + q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1) + q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim) + k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim) + attn_weights = torch.matmul(q, k.transpose(-1, -2)) + + sq, sk = q.size(-2), k.size(-2) + causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool() + total_mask = ~(causal_mask & (attention_mask == 0)) + """ + tensor([[[[False, True, True, ..., True, True, True], + [False, False, True, ..., True, True, True], + [False, False, False, ..., True, True, True], + ..., + [False, False, False, ..., False, True, True], + [False, False, False, ..., False, False, True], + [False, False, False, ..., False, False, False]]] + """ + + fused_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=True, + ) + .cuda() + .half() + ) + + fused_softmax_output = fused_softmax( + attn_weights, + total_mask, + ) + + torch_softmax = ( + FusedScaleMaskSoftmax( + input_in_fp16=True, + input_in_bf16=False, + mask_func=attention_mask_func, + scale=None, + softmax_in_fp32=False, + attn_mask_type=AttnMaskType.causal, + scaled_masked_softmax_fusion=False, + ) + .cuda() + .half() + ) + + torch_softmax_output = torch_softmax( + attn_weights, + total_mask, + ) + + test_result = (fused_softmax_output - torch_softmax_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}" + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_fused_upper_triangle_mask_softmax" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}" + ) + + +def test_layer_norm(): + bert = BertModel.from_pretrained("bert-base-cased").cuda().half() + tokenizer = BertTokenizer.from_pretrained("bert-base-cased") + test_text = ( + "Hello. How are you? I am fine thank you and you? yes Good. 
" + "hi hi hi hi hi hi hi hi hi hi hi hi hi" # 32 + ) + + tokens = tokenizer( + [test_text] * 4, + return_tensors="pt", + ) + + # [bsz, seq_len, d_model] + embedding_output = ( + bert.embeddings( + input_ids=tokens["input_ids"].cuda(), + position_ids=None, + token_type_ids=tokens["token_type_ids"].cuda(), + inputs_embeds=None, + past_key_values_length=0, + ) + .cuda() + .half() + ) + + fused_layernorm_layer = ( + MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + torch_layernorm_layer = ( + LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half() + ) + + fused_output = fused_layernorm_layer(embedding_output) + torch_output = torch_layernorm_layer(embedding_output) + test_result = (fused_output - torch_output).abs() + + while test_result.dim() != 1: + test_result = test_result.mean(dim=-1) + + diff = test_result.mean(dim=-1) + + if diff <= 1e-3: + print( + f"\n[Success] test_layer_norm" + f"\n > mean_difference={diff}" + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}" + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + else: + print( + f"\n[Fail] test_layer_norm" + f"\n > mean_difference={diff}, " + f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, " + f"\n > torch_values={torch_output[-1][-1][:5].tolist()}" + ) + + +if __name__ == "__main__": + try: + from transformers import BertTokenizer, GPT2Tokenizer + from transformers.models.bert.modeling_bert import BertModel + from transformers.models.gpt2.modeling_gpt2 import GPT2Model + import transformers + + transformers.logging.set_verbosity( + transformers.logging.FATAL, + ) + + except: + print("\n[Fail] Please install `transformers` package to test fused kernels\n") + exit(-1) + + test_load_fused_kernels() + test_fused_softmax() + test_fused_upper_triangle_mask_softmax() + test_layer_norm() diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 097b29ef4c..7b047dfa61 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -13,7 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. + import torch +import torch.nn as nn from megatron.model.enums import AttnMaskType @@ -30,10 +32,10 @@ def forward(ctx, inputs, scale): import scaled_upper_triang_masked_softmax_cuda scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward( inputs, scale_t[0] ) + ctx.save_for_backward(softmax_results, scale_t) return softmax_results @@ -42,10 +44,10 @@ def backward(ctx, output_grads): import scaled_upper_triang_masked_softmax_cuda softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_cuda.backward( output_grads, softmax_results, scale_t[0] ) + return input_grads, None @@ -63,9 +65,7 @@ def forward(ctx, inputs, mask, scale): scale_t = torch.tensor([scale]) - softmax_results = scaled_masked_softmax_cuda.forward( - inputs, mask, scale_t[0] - ) + softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) ctx.save_for_backward(softmax_results, scale_t) return softmax_results @@ -81,16 +81,18 @@ def backward(ctx, output_grads): return input_grads, None, None -class FusedScaleMaskSoftmax(torch.nn.Module): +class FusedScaleMaskSoftmax(nn.Module): """ fused operation: scaling + mask + softmax + Arguments: input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. 
diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py
index 097b29ef4c..7b047dfa61 100644
--- a/megatron/model/fused_softmax.py
+++ b/megatron/model/fused_softmax.py
@@ -13,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+
 import torch
+import torch.nn as nn

 from megatron.model.enums import AttnMaskType

@@ -30,10 +32,10 @@ def forward(ctx, inputs, scale):
         import scaled_upper_triang_masked_softmax_cuda

         scale_t = torch.tensor([scale])
-
         softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(
             inputs, scale_t[0]
         )
+
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -42,10 +44,10 @@ def backward(ctx, output_grads):
         import scaled_upper_triang_masked_softmax_cuda

         softmax_results, scale_t = ctx.saved_tensors
-
         input_grads = scaled_upper_triang_masked_softmax_cuda.backward(
             output_grads, softmax_results, scale_t[0]
         )
+
         return input_grads, None

@@ -63,9 +65,7 @@ def forward(ctx, inputs, mask, scale):

         scale_t = torch.tensor([scale])

-        softmax_results = scaled_masked_softmax_cuda.forward(
-            inputs, mask, scale_t[0]
-        )
+        softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -81,16 +81,18 @@ def backward(ctx, output_grads):
         return input_grads, None, None


-class FusedScaleMaskSoftmax(torch.nn.Module):
+class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax
+
     Arguments:
         input_in_fp16: flag to indicate if input in fp16 data format.
+        input_in_bf16: flag to indicate if input in bf16 data format.
         attn_mask_type: attention mask type (pad or causal)
+        scaled_masked_softmax_fusion: flag to indicate user wants to use softmax fusion
         mask_func: mask function to be applied.
         softmax_in_fp32: if true, softmax is performed at fp32 precision.
         scale: scaling factor used in input tensor scaling.
-
     """

     def __init__(
@@ -106,8 +108,9 @@ def __init__(
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
-        assert not (self.input_in_fp16 and self.input_in_bf16),\
-            'both fp16 and bf16 flags cannot be active at the same time.'
+        assert not (
+            self.input_in_fp16 and self.input_in_bf16
+        ), "both fp16 and bf16 flags cannot be active at the same time."
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
@@ -118,47 +121,72 @@ def __init__(
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"
-
+
     def forward(self, input, mask):
         # [b, np, sq, sk]
         assert input.dim() == 4
-        data_size = input.size()
-        query_seq_len = data_size[-2]
-        key_seq_len = data_size[-1]
-        attn_batch_size = data_size[0] * data_size[1]
-
-        # constraints on various tensor dimensions to enable warp based
-        # optimization and upper triangular optimization (for causal mask)
-        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
-            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
-
-        # invoke custom kernel
-        if self.input_in_float16 and mask is not None and \
-                custom_kernel_constraint and self.scaled_masked_softmax_fusion:
-            scale = self.scale if self.scale is not None else 1.0
-
-            if self.attn_mask_type == AttnMaskType.causal:
-                assert query_seq_len == key_seq_len, \
-                    "causal mask is only for self attention"
-                input = input.view(-1, query_seq_len, key_seq_len)
-                probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
-                probs = probs.view(*data_size)
-            else:
-                assert self.attn_mask_type == AttnMaskType.padding
-                probs = ScaledMaskedSoftmax.apply(input, mask, scale)
+
+        if self.is_kernel_available(mask, *input.size()):
+            return self.forward_fused_softmax(input, mask)
         else:
-            if self.input_in_float16 and self.softmax_in_fp32:
-                input = input.float()
+            return self.forward_torch_softmax(input, mask)
+
+    def is_kernel_available(self, mask, b, np, sq, sk):
+        attn_batches = b * np
+
+        if (
+            self.scaled_masked_softmax_fusion  # user wants to fuse
+            and self.input_in_float16  # input must be fp16 or bf16
+            and mask is not None  # mask tensor must not be None
+            and 16 < sk <= 2048  # sk must be in (16, 2048]
+            and sq % 4 == 0  # sq must be a multiple of 4
+            and attn_batches % 4 == 0  # np * b must be a multiple of 4
+        ):
+            if 0 <= sk <= 2048:
+                batch_per_block = self.get_batch_per_block(sq, sk, b, np)
+
+                if self.attn_mask_type == AttnMaskType.causal:
+                    if attn_batches % batch_per_block == 0:
+                        return True
+                else:
+                    if sq % batch_per_block == 0:
+                        return True
+        return False

-            if self.scale is not None:
-                input = input * self.scale
-            mask_output = self.mask_func(input, mask) if mask is not None else input
-            probs = torch.nn.Softmax(dim=-1)(mask_output)
+    def forward_fused_softmax(self, input, mask):
+        b, np, sq, sk = input.size()
+        scale = self.scale if self.scale is not None else 1.0

-            if self.input_in_float16 and self.softmax_in_fp32:
-                if self.input_in_fp16:
-                    probs = probs.half()
-                else:
-                    probs = probs.bfloat16()
+        if self.attn_mask_type == AttnMaskType.causal:
+            assert sq == sk, "causal mask is only for self attention"
+
+            # reshape to a 3D tensor: (attn_batches, sq, sk)
+            input = input.view(-1, sq, sk)
+            probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale)
+            return probs.view(b, np, sq, sk)
+        else:
+            # input is a 4D tensor: (b, np, sq, sk)
+            return ScaledMaskedSoftmax.apply(input, mask, scale)
+
+    def forward_torch_softmax(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()

         return probs
+
+    @staticmethod
+    def get_batch_per_block(sq, sk, b, np):
+        import scaled_masked_softmax_cuda
+
+        return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np)
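Finally, a hypothetical end-to-end sketch of the new dispatch path. The shapes are invented for illustration, and running it requires a CUDA device plus the compiled scaled_masked_softmax_cuda extension:

import torch

from megatron.model.enums import AttnMaskType
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func

softmax = FusedScaleMaskSoftmax(
    input_in_fp16=True,
    input_in_bf16=False,
    attn_mask_type=AttnMaskType.padding,
    scaled_masked_softmax_fusion=True,
    mask_func=attention_mask_func,
    softmax_in_fp32=False,
    scale=None,
)

# b=8, np=16, sq=sk=1024: 16 < sk <= 2048, sq is a multiple of 4, and
# b * np = 128 is a multiple of 4; sk=1024 gives batch_per_block = 4 and
# sq % 4 == 0, so is_kernel_available returns True and the CUDA path runs.
scores = torch.randn(8, 16, 1024, 1024, dtype=torch.float16, device="cuda")
mask = torch.zeros(8, 1, 1024, 1024, dtype=torch.bool, device="cuda")  # True = masked out

assert softmax.is_kernel_available(mask, 8, 16, 1024, 1024)
probs = softmax(scores, mask)  # routed to forward_fused_softmax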