Merge branch 'github_fused_softmax' into 'main'
Fused softmax checks and additions from Github (#133)

See merge request ADLR/megatron-lm!312
jaredcasper committed Aug 24, 2021
2 parents 23266c5 + bc7b353 commit 0be4052
Showing 7 changed files with 415 additions and 47 deletions.
22 changes: 21 additions & 1 deletion megatron/fused_kernels/scaled_masked_softmax.cpp
@@ -32,6 +32,12 @@ torch::Tensor bwd_cuda(
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
@@ -63,6 +69,14 @@ torch::Tensor bwd(
  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
@@ -71,7 +85,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");

  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");

  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
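
For reference, a minimal sketch of calling the new binding from Python, assuming the extension was built under the name scaled_masked_softmax_cuda (the module name imported by the test file added below) and using illustrative shape values that are not taken from this commit:

    # Query the fused kernel's batching granularity through the new pybind entry point.
    # Assumes the extension compiled as `scaled_masked_softmax_cuda`; shapes are illustrative.
    import scaled_masked_softmax_cuda

    query_seq_len, key_seq_len, batches, attn_heads = 1024, 1024, 8, 16

    batch_per_block = scaled_masked_softmax_cuda.get_batch_per_block(
        query_seq_len, key_seq_len, batches, attn_heads
    )
    print(f"batches per block: {batch_per_block}")
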
15 changes: 14 additions & 1 deletion megatron/fused_kernels/scaled_masked_softmax.h
@@ -310,9 +310,22 @@ __global__ void scaled_masked_softmax_warp_backward(
        }
    }
}

} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}
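
To make the arithmetic above concrete, here is a hedged pure-Python mirror of get_batch_per_block, assuming C10_WARP_SIZE is the usual CUDA warp size of 32; it illustrates the formula and is not part of the commit:

    import math

    def get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads):
        # Round key_seq_len up to the next power of two, as log2_ceil does.
        log2_elements = math.ceil(math.log2(max(key_seq_len, 1)))
        next_power_of_two = 1 << log2_elements
        warp_size = min(next_power_of_two, 32)        # C10_WARP_SIZE assumed to be 32
        batches_per_warp = 2 if next_power_of_two <= 128 else 1
        threads_per_block = 128
        warps_per_block = threads_per_block // warp_size
        return warps_per_block * batches_per_warp     # batches_per_block

    assert get_batch_per_block(128, 128, 8, 16) == 8      # 4 warps/block * 2 batches/warp
    assert get_batch_per_block(2048, 2048, 8, 16) == 4    # 4 warps/block * 1 batch/warp
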

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
5 changes: 5 additions & 0 deletions megatron/fused_kernels/scaled_masked_softmax_cuda.cu
@@ -28,6 +28,11 @@ namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}


torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
2 changes: 2 additions & 0 deletions megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
@@ -361,6 +361,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
@@ -451,6 +452,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(
        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(attn_batches % batches_per_block == 0);

        int blocks_per_seq = attn_batches / batches_per_block;
        dim3 blocks(seq_len, blocks_per_seq, 1);
        dim3 threads(warp_size, warps_per_block, 1);
Empty file.
300 changes: 300 additions & 0 deletions megatron/fused_kernels/tests/test_fused_kernels.py
@@ -0,0 +1,300 @@
import math

import torch
from torch.nn import LayerNorm

from megatron.model.enums import AttnMaskType
from megatron.model.fused_layer_norm import MixedFusedLayerNorm
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.utils import attention_mask_func


def test_load_fused_kernels():
    try:
        import fused_mix_prec_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch

        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e


def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    embedding_output = bert.embeddings(
        input_ids=tokens["input_ids"].cuda(),
        position_ids=None,
        token_type_ids=tokens["token_type_ids"].cuda(),
        inputs_embeds=None,
        past_key_values_length=0,
    )

    # (bsz, 1, 1, seq_len)
    mask = bert.get_extended_attention_mask(
        attention_mask=tokens["attention_mask"].cuda(),
        input_shape=tokens["input_ids"].shape,
        device=bert.device,
    )
    # (bsz, 1, seq_len, seq_len)
    mask = mask.repeat(1, 1, mask.size()[-1], 1)

    attention = bert.encoder.layer[0].attention.self
    key_layer = attention.transpose_for_scores(attention.key(embedding_output))
    query_layer = attention.transpose_for_scores(attention.query(embedding_output))

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
    attention_scores /= math.sqrt(key_layer.size()[-1])

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attention_scores,
        (mask != 0),
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.padding,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attention_scores,
        (mask != 0),
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_fused_upper_triangle_mask_softmax():
    gpt = GPT2Model.from_pretrained("gpt2").cuda().half()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi"  # 24
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    attention_mask = tokens["attention_mask"].cuda()
    attention_mask = attention_mask.view(attention_mask.size(0), -1)
    attention_mask = attention_mask[:, None, None, :]
    attention_mask = (1.0 - attention_mask) * -10000.0
    attention_mask = attention_mask.repeat(1, 1, attention_mask.size()[-1], 1)
    attn = gpt.h[0]

    hidden_states = gpt.wte(tokens["input_ids"].cuda())
    q, k, v = attn.attn.c_attn(hidden_states).split(768, dim=-1)
    q = attn.attn._split_heads(q, attn.attn.num_heads, attn.attn.head_dim)
    k = attn.attn._split_heads(k, attn.attn.num_heads, attn.attn.head_dim)
    attn_weights = torch.matmul(q, k.transpose(-1, -2))

    sq, sk = q.size(-2), k.size(-2)
    causal_mask = attn.attn.bias[:, :, sk - sq : sk, :sk].bool()
    total_mask = ~(causal_mask & (attention_mask == 0))
    """
    tensor([[[[False, True, True, ..., True, True, True],
              [False, False, True, ..., True, True, True],
              [False, False, False, ..., True, True, True],
              ...,
              [False, False, False, ..., False, True, True],
              [False, False, False, ..., False, False, True],
              [False, False, False, ..., False, False, False]]]
    """

    fused_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=True,
        )
        .cuda()
        .half()
    )

    fused_softmax_output = fused_softmax(
        attn_weights,
        total_mask,
    )

    torch_softmax = (
        FusedScaleMaskSoftmax(
            input_in_fp16=True,
            input_in_bf16=False,
            mask_func=attention_mask_func,
            scale=None,
            softmax_in_fp32=False,
            attn_mask_type=AttnMaskType.causal,
            scaled_masked_softmax_fusion=False,
        )
        .cuda()
        .half()
    )

    torch_softmax_output = torch_softmax(
        attn_weights,
        total_mask,
    )

    test_result = (fused_softmax_output - torch_softmax_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_fused_upper_triangle_mask_softmax"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_softmax_output[-1][-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_softmax_output[-1][-1][-1][:5].tolist()}"
        )


def test_layer_norm():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    test_text = (
        "Hello. How are you? I am fine thank you and you? yes Good. "
        "hi hi hi hi hi hi hi hi hi hi hi hi hi"  # 32
    )

    tokens = tokenizer(
        [test_text] * 4,
        return_tensors="pt",
    )

    # [bsz, seq_len, d_model]
    embedding_output = (
        bert.embeddings(
            input_ids=tokens["input_ids"].cuda(),
            position_ids=None,
            token_type_ids=tokens["token_type_ids"].cuda(),
            inputs_embeds=None,
            past_key_values_length=0,
        )
        .cuda()
        .half()
    )

    fused_layernorm_layer = (
        MixedFusedLayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    torch_layernorm_layer = (
        LayerNorm(normalized_shape=embedding_output.size(-1)).cuda().half()
    )

    fused_output = fused_layernorm_layer(embedding_output)
    torch_output = torch_layernorm_layer(embedding_output)
    test_result = (fused_output - torch_output).abs()

    while test_result.dim() != 1:
        test_result = test_result.mean(dim=-1)

    diff = test_result.mean(dim=-1)

    if diff <= 1e-3:
        print(
            f"\n[Success] test_layer_norm"
            f"\n > mean_difference={diff}"
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}"
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )
    else:
        print(
            f"\n[Fail] test_layer_norm"
            f"\n > mean_difference={diff}, "
            f"\n > fused_values={fused_output[-1][-1][:5].tolist()}, "
            f"\n > torch_values={torch_output[-1][-1][:5].tolist()}"
        )


if __name__ == "__main__":
    try:
        from transformers import BertTokenizer, GPT2Tokenizer
        from transformers.models.bert.modeling_bert import BertModel
        from transformers.models.gpt2.modeling_gpt2 import GPT2Model
        import transformers

        transformers.logging.set_verbosity(
            transformers.logging.FATAL,
        )

    except:
        print("\n[Fail] Please install `transformers` package to test fused kernels\n")
        exit(-1)

    test_load_fused_kernels()
    test_fused_softmax()
    test_fused_upper_triangle_mask_softmax()
    test_layer_norm()
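
The new checks are plain scripts rather than pytest cases; assuming a CUDA-capable machine with the fused extensions already built and the transformers package installed, they can be run directly with python megatron/fused_kernels/tests/test_fused_kernels.py from the repository root.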
