## Validate Model Forward and Backward

In [138]:
import numpy as np
import torch

In [139]:
Dtype = 'bfloat16'
ATOL = {
    'float32': 1e-3,
    'bfloat16': 5e-2,
    'float16': 5e-2,
}
RTOL = {
    'float32': 1.3e-6,
    'bfloat16': 1e-2,
    'float16': 1e-2,
}
atol = ATOL[Dtype]
rtol = RTOL[Dtype]
QKV_LINEAR_FUSION = True

In [140]:
cache_dir = "/root/Megatron-LM/cache"

def inspect_output(hf_array, megatron_array):
    print(f"hf_array.shape: {hf_array.shape}, megatron_array.shape: {megatron_array.shape}")
    diff = np.abs(hf_array-megatron_array)
    min_diff = diff.min()
    max_diff = diff.max()
    mean_diff = diff.mean()
    print(f"min_diff: {min_diff}, max_diff: {max_diff}, mean_diff: {mean_diff}")
    r_diff = diff/(np.abs(hf_array)+1e-7)
    min_r_diff = r_diff.min()
    max_r_diff = r_diff.max()
    mean_r_diff = r_diff.mean()
    print(f"min_r_diff: {min_r_diff}, max_r_diff: {max_r_diff}, mean_r_diff: {mean_r_diff}")

In [141]:
embedding_hf = np.load("/root/Megatron-LM/cache/hf_model.embed_tokens.npy")
embedding_megatron = np.load("/root/Megatron-LM/cache/megatron_embedding.word_embeddings.npy")
embedding_flag = np.allclose(embedding_hf, embedding_megatron, atol=atol, rtol=rtol)
print("embedding: ", embedding_flag)
inspect_output(embedding_hf, embedding_megatron)
# embedding dropout
embedding_dropout_hf = np.load("/root/Megatron-LM/cache/hf_model.dropout.npy")
embedding_dropout_megatron = np.load("/root/Megatron-LM/cache/megatron_embedding.embedding_dropout.npy")
dropout_flag = np.allclose(embedding_dropout_hf, embedding_dropout_megatron.transpose(1,0,2), atol=atol, rtol=rtol)
print("embedding dropout: ", dropout_flag)
inspect_output(embedding_dropout_hf, embedding_dropout_megatron.transpose(1,0,2))

embedding:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0
min_r_diff: 0.0, max_r_diff: 0.0, mean_r_diff: 0.0
embedding dropout:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0
min_r_diff: 0.0, max_r_diff: 0.0, mean_r_diff: 0.0


In [142]:
if not QKV_LINEAR_FUSION:
    rmsnorm_output_hf = np.load(f"{cache_dir}/hf_model.layers.0.input_layernorm.npy")
    rmsnorm_output_megatron = np.load(f"{cache_dir}/megatron_decoder.layers.0.input_layernorm.npy").reshape(1, 128, -1)
    inspect_output(rmsnorm_output_hf, rmsnorm_output_megatron)

In [143]:
# qkv_proj
if QKV_LINEAR_FUSION:
    qkv_megatron = np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.linear_qkv_o0.npy")
    q_megatron = qkv_megatron.transpose(1,0,2).reshape(1, 128, 8, -1)[:, :, :, :512].reshape(1, 128, -1)
    k_megatron = qkv_megatron.transpose(1,0,2).reshape(1, 128, 8, -1)[:, :, :, 512:640].reshape(1, 128, -1)
    v_megatron = qkv_megatron.transpose(1,0,2).reshape(1, 128, 8, -1)[:, :, :, 640:].reshape(1, 128, -1)
else:
    q_megatron = np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.linear_q_o0.npy").reshape(1, 128, -1)
    k_megatron = np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.linear_k_o0.npy").reshape(1, 128, -1)
    v_megatron = np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.linear_v_o0.npy").reshape(1, 128, -1)
q_hf = np.load(f"{cache_dir}/hf_model.layers.0.self_attn.q_proj.npy")
k_hf = np.load(f"{cache_dir}/hf_model.layers.0.self_attn.k_proj.npy")
v_hf = np.load(f"{cache_dir}/hf_model.layers.0.self_attn.v_proj.npy")
layer_norm_flag = np.allclose(q_megatron, q_hf, atol=atol, rtol=rtol) 
print("linear q: ", layer_norm_flag)
inspect_output(q_megatron, q_hf)
layer_norm_flag = np.allclose(k_megatron, k_hf, atol=atol, rtol=rtol)
print("linear k: ", layer_norm_flag)
inspect_output(k_megatron, k_hf)
layer_norm_flag = np.allclose(v_megatron, v_hf, atol=atol, rtol=rtol)
print("linear v: ", layer_norm_flag)
inspect_output(v_megatron, v_hf)

linear q:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.000721507822163403
min_r_diff: 0.0, max_r_diff: 420.6923522949219, mean_r_diff: 0.010125597007572651
linear k:  True
hf_array.shape: (1, 128, 1024), megatron_array.shape: (1, 128, 1024)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.0011264100903645158
min_r_diff: 0.0, max_r_diff: 116.12904357910156, mean_r_diff: 0.011339498683810234
linear v:  True
hf_array.shape: (1, 128, 1024), megatron_array.shape: (1, 128, 1024)
min_diff: 0.0, max_diff: 0.00390625, mean_diff: 7.616392394993454e-05
min_r_diff: 0.0, max_r_diff: 275.34375, mean_r_diff: 0.030313245952129364


In [144]:
# attention output
for i in range(32):
    attention_output_hf = np.load(f"{cache_dir}/hf_model.layers.{i}.self_attn_o0.npy")
    attention_output_megatron = np.load(f"{cache_dir}/megatron_decoder.layers.{i}.self_attention_o0.npy").transpose(1,0,2)
    print(f"layer {i}: ", np.allclose(attention_output_hf, attention_output_megatron, atol=atol, rtol=rtol))
    inspect_output(attention_output_hf, attention_output_megatron)

layer 0:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.00390625, mean_diff: 1.3048614164290484e-05
min_r_diff: 0.0, max_r_diff: 610.3515625, mean_r_diff: 0.08150769770145416
layer 1:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0234375, mean_diff: 0.00023041712120175362
min_r_diff: 0.0, max_r_diff: 7476.806640625, mean_r_diff: 0.6095441579818726
layer 2:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.0002974425442516804
min_r_diff: 0.0, max_r_diff: 14038.0859375, mean_r_diff: 0.8078778386116028
layer 3:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.0004942230298183858
min_r_diff: 0.0, max_r_diff: 26855.46875, mean_r_diff: 1.3757848739624023
layer 4:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
mi

In [145]:
output_hf = np.load(f"{cache_dir}/hf_lm_head.npy")
output_megatron = np.load(f"{cache_dir}/megatron_output_layer_o0.npy").transpose(1,0,2)
inspect_output(output_hf, output_megatron)
np.allclose(output_hf, output_megatron, atol=atol, rtol=rtol)

output_hf_label = output_hf.argmax(axis=-1)
output_megatron_label = output_megatron.argmax(axis=-1)
print(np.sum(output_hf_label!=output_megatron_label))
print(output_hf_label[output_hf_label!=output_megatron_label])
print(output_megatron_label[output_hf_label!=output_megatron_label])

hf_array.shape: (1, 128, 184622), megatron_array.shape: (1, 128, 184622)
min_diff: 0.0, max_diff: 0.2265625, mean_diff: 0.010856513865292072
min_r_diff: 0.0, max_r_diff: 23336.587890625, mean_r_diff: 0.08597517013549805
1
[110205]
[100873]


In [123]:
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

In [124]:
# rotary embedding
cos_hf = np.load(f"{cache_dir}/hf_model.layers.1.self_attn.rotary_emb_o0.npy")
sin_hf = np.load(f"{cache_dir}/hf_model.layers.1.self_attn.rotary_emb_o1.npy")
cos_megatron = np.cos(np.load(f"{cache_dir}/megatron_rotary_pos_emb.npy")).reshape(128, 128)
sin_megatron = np.sin(np.load(f"{cache_dir}/megatron_rotary_pos_emb.npy")).reshape(128, 128)
inspect_output(cos_hf, cos_megatron)
inspect_output(sin_hf, sin_megatron)

hf_array.shape: (128, 128), megatron_array.shape: (128, 128)
min_diff: 0.0, max_diff: 0.0019521117210388184, mean_diff: 0.0004719264688901603
min_r_diff: 0.0, max_r_diff: 0.0038003837689757347, mean_r_diff: 0.000658641045447439
hf_array.shape: (128, 128), megatron_array.shape: (128, 128)
min_diff: 0.0, max_diff: 0.001952826976776123, mean_diff: 0.0003145454975310713
min_r_diff: 0.0, max_r_diff: 0.0038327821530401707, mean_r_diff: 0.0013541332446038723


In [125]:
freq_megatron = np.load(f'{cache_dir}/megatron_rotary_pos_emb.npy')

In [126]:
q = torch.from_numpy(q_hf).cuda().reshape(1, 128, 32, -1).transpose(1, 2)
k = torch.from_numpy(k_hf).cuda().reshape(1, 128, 8, -1).transpose(1, 2)
v = torch.from_numpy(v_hf).cuda().reshape(1, 128, 8, -1).transpose(1, 2)
cos, sin = torch.from_numpy(np.load(f"{cache_dir}/hf_model.layers.0.self_attn.rotary_emb_o0.npy")), torch.from_numpy(np.load(f"{cache_dir}/hf_model.layers.0.self_attn.rotary_emb_o1.npy"))
q, k = apply_rotary_pos_emb(q, k, cos.cuda(), sin.cuda(), position_ids=torch.arange(128, device="cuda").unsqueeze(0))
k = k[:, :, None, :, :].expand(-1, -1, 4, -1, -1).reshape(1, 32, 128, 128)
v = v[:, :, None, :, :].expand(-1, -1, 4, -1, -1).reshape(1, 32, 128, 128)

output = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True).permute(0, 2, 1, 3).reshape(1, 128, -1).cpu().numpy()

## Validate TE

In [1]:
import torch
import torch.nn as nn
import transformer_engine.pytorch as te
import numpy as np
import os

os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG "] = ":4096:8"

# torch.backends.cudnn.allow_tf32 = True
# torch.backends.cuda.matmul.allow_tf32 = True


### rmsnorm

In [88]:


class Emu3RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        Emu3RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

class TorchRMSNorm(nn.Module):
    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        initial_value = torch.ones(in_features) if zero_centered_gamma else torch.zeros(in_features)
        self.weight = nn.Parameter(initial_value)
        self.register_parameter("weight", self.weight)

    def forward(self, x):
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)

In [55]:
# case I, float32 input

input = np.load(f"{cache_dir}/hf_model.dropout.npy")
input = torch.from_numpy(input).cuda().bfloat16()

weight = np.load(f"{cache_dir}/megatron_rmsnorm_weight.npy")
weight = torch.from_numpy(weight).cuda().bfloat16()

hf_rms_norm = Emu3RMSNorm(4096, eps=1e-5)
hf_rms_norm.weight.requires_grad = False

te_rms_norm = te.RMSNorm(4096, eps=1e-5, params_dtype=torch.bfloat16)
te_rms_norm.weight.requires_grad = False

torch_rms_norm = torch.nn.RMSNorm(4096, eps=1e-5)
torch_rms_norm.weight.requires_grad = False

torch_rms_norm_2 = TorchRMSNorm(4096, zero_centered_gamma=False, eps=1e-5)
torch_rms_norm_2.weight.requires_grad = False

hf_rms_norm.weight = torch.nn.Parameter(weight)
te_rms_norm.weight = torch.nn.Parameter(weight)
torch_rms_norm.weight = torch.nn.Parameter(weight)
torch_rms_norm_2.weight = torch.nn.Parameter(weight)

with torch.no_grad():
    output_hf = hf_rms_norm(input).cpu().float().numpy()
    output_te = te_rms_norm(input).cpu().float().numpy()
    output_torch = torch_rms_norm(input).cpu().float().numpy()
    output_torch_2 = torch_rms_norm_2(input).cpu().float().numpy()
inspect_output(output_hf, output_te)
inspect_output(output_hf, output_torch)
inspect_output(output_hf, output_torch_2)
inspect_output(output_te, output_torch)
inspect_output(output_te, output_torch_2)

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082487804349512
min_r_diff: 0.0, max_r_diff: 0.007812499068677425, mean_r_diff: 0.0013991205487400293
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00021006364841014147
min_r_diff: 0.0, max_r_diff: 0.013888886198401451, mean_r_diff: 0.0016162245301529765
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082429596688598
min_r_diff: 0.0, max_r_diff: 0.007812499068677425, mean_r_diff: 0.0013990971492603421
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.0002690280962269753
min_r_diff: 0.0, max_r_diff: 0.014084496535360813, mean_r_diff: 0.0020730604883283377
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.000244140

### qkv Linear

In [89]:
with torch.no_grad():
    input = hf_rms_norm(input)

In [90]:

te_weight = np.load(f'{cache_dir}/megatron_qkv_linear_weight.npy')
hf_q_weight = np.load(f'{cache_dir}/hf_q_linear_weight.npy')


te_linear = te.Linear(6144, 4096, params_dtype=torch.bfloat16, bias=False, device="cuda")
te_linear.weight.requires_grad=False
te_linear.weight = torch.nn.Parameter(torch.from_numpy(te_weight).cuda().bfloat16())

hf_q_linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16)
hf_q_linear.weight.requires_grad=False
hf_q_linear.weight = torch.nn.Parameter(torch.from_numpy(hf_q_weight).cuda().bfloat16())

hf_q_linear_fp32 = torch.nn.Linear(4096, 4096, bias=False)
hf_q_linear_fp32.weight.requires_grad=False
hf_q_linear_fp32.weight = torch.nn.Parameter(torch.from_numpy(hf_q_weight).cuda())

with torch.no_grad():
    q_te = te_linear(input).reshape(1, 128, 8, -1)[:, :, :, :512].reshape(1, 128, -1)
    q_hf = hf_q_linear(input)
    q_hf_fp32 = hf_q_linear_fp32(input.float()).bfloat16()
inspect_output(q_te.cpu().float().numpy(), q_hf.cpu().float().numpy())
inspect_output(q_te.cpu().float().numpy(), q_hf_fp32.cpu().float().numpy()) 
inspect_output(q_hf.cpu().float().numpy(), q_hf_fp32.cpu().float().numpy())

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.00014284173084888607
min_r_diff: 0.0, max_r_diff: 29.493080139160156, mean_r_diff: 0.002411804161965847
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.00014285619545262307
min_r_diff: 0.0, max_r_diff: 29.493080139160156, mean_r_diff: 0.002412014175206423
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0078125, mean_diff: 3.199484126525931e-07
min_r_diff: 0.0, max_r_diff: 0.0712229534983635, mean_r_diff: 3.683098839246668e-06


In [91]:
te_q_linear_weight = np.load(f'{cache_dir}/megatron_v_linear_weight.npy')
hf_q_linear_weight = np.load(f'{cache_dir}/hf_v_linear_weight.npy')
inspect_output(te_q_linear_weight, hf_q_linear_weight)

hf_array.shape: (1024, 4096), megatron_array.shape: (1024, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0
min_r_diff: 0.0, max_r_diff: 0.0, mean_r_diff: 0.0


In [92]:
te_q_linear_weight = np.load(f'{cache_dir}/megatron_qkv_linear_weight.npy').reshape(8, -1, 4096)[:, :512, :].reshape(4096, 4096)
hf_q_linear_weight = np.load(f'{cache_dir}/hf_q_linear_weight.npy')
inspect_output(te_q_linear_weight, hf_q_linear_weight)

hf_array.shape: (4096, 4096), megatron_array.shape: (4096, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0
min_r_diff: 0.0, max_r_diff: 0.0, mean_r_diff: 0.0


In [60]:
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
te_linear_q = te.Linear(4096, 4096, params_dtype=torch.bfloat16, bias=False, device="cuda")
te_linear_q.weight.requires_grad=False
te_linear_q.weight = torch.nn.Parameter(torch.from_numpy(hf_q_weight).cuda().bfloat16())
with torch.no_grad():
    q_te = te_linear_q(input)
inspect_output(q_te.cpu().float().numpy(), q_hf.cpu().float().numpy())

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0
min_r_diff: 0.0, max_r_diff: 0.0, mean_r_diff: 0.0


In [61]:
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch_linear_qkv = torch.nn.Linear(6144, 4096, bias=False, dtype=torch.bfloat16)
torch_linear_qkv.weight.requires_grad=False
torch_linear_qkv.weight = torch.nn.Parameter(torch.from_numpy(te_weight).cuda().bfloat16())
with torch.no_grad():
    q_torch = torch_linear_qkv(input).reshape(1, 128, 8, -1)[:, :, :, :512].reshape(1, 128, -1)
inspect_output(q_torch.cpu().float().numpy(), q_hf.cpu().float().numpy())

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.0002010889584198594
min_r_diff: 0.0, max_r_diff: 118.507568359375, mean_r_diff: 0.004296050872653723


### Attention

In [78]:
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"

In [85]:
from typing import Optional
import math

from transformer_engine.pytorch.utils import attention_mask_func

class TorchScaledMaskedSoftmax(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inp: torch.Tensor, mask: torch.Tensor, scale: Optional[float] = None
    ) -> torch.Tensor:
        dtype = inp.dtype
        inp = inp.float()

        if scale is not None:
            inp = inp * scale
        mask_output = attention_mask_func(inp, mask) if mask is not None else inp

        probs = torch.nn.Softmax(dim=-1)(mask_output)
        probs = probs.to(dtype)
        return probs

class TorchDotProductAttention(torch.nn.Module):
    def __init__(
        self,
        kv_channels: int,
        attention_dropout: float = 0.0,
    ) -> None:
        super().__init__()

        self.norm_factor = math.sqrt(kv_channels)
        self.scale_mask_softmax = TorchScaledMaskedSoftmax()
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer: torch.Tensor,
        key_layer: torch.Tensor,
        value_layer: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, seqlen = query_layer.shape[1], query_layer.shape[0]

        # [b, np, sq, sk]
        output_size = (
            query_layer.size(1),
            query_layer.size(2),
            query_layer.size(0),
            key_layer.size(0),
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.reshape(output_size[3], output_size[0] * output_size[1], -1)

        # preallocting result tensor: [b * np, sq, sk]
        matmul_result = torch.empty(
            output_size[0] * output_size[1],
            output_size[2],
            output_size[3],
            dtype=query_layer.dtype,
            device=torch.cuda.current_device(),
        )

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_result,
            query_layer.transpose(0, 1),  # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
            beta=0.0,
            alpha=(1.0 / self.norm_factor),
        )

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
        attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.size(1),
            value_layer.size(2),
            query_layer.size(0),
            value_layer.size(3),
        )

        # change view [sk, b * np, hn]
        value_layer = value_layer.reshape(value_layer.size(0), output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
        

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        context_layer = context_layer.view(seqlen, batch_size, -1)

        return context_layer

In [83]:
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False

In [94]:
q = torch.from_numpy(np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.core_attention_input_o0.npy")).cuda().bfloat16()
k = torch.from_numpy(np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.core_attention_input_o1.npy")).cuda().bfloat16()
v = torch.from_numpy(np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.core_attention_input_o2.npy")).cuda().bfloat16()

attention_mask_megatron = torch.triu(torch.ones((128, 128)), diagonal=1).bool().cuda()
attention_mask_hf = torch.triu(torch.ones((128, 128)), diagonal=1).bool()
attention_mask_hf = attention_mask_hf.unsqueeze(0).unsqueeze(0)
attention_mask_hf = attention_mask_hf.to(device="cuda")
attention_mask_hf = torch.zeros_like(attention_mask_hf, dtype=torch.float32).masked_fill(attention_mask_megatron, -10000).bfloat16()

attn_te = te.DotProductAttention(num_attention_heads=32, num_gqa_groups=8, kv_channels=128, attention_dropout=0.0, attn_mask_type="causal")
output_te = attn_te(q, k, v).transpose(1, 0).cpu().float().numpy()
attn_torch = TorchDotProductAttention(kv_channels=128, attention_dropout=0.0)
k2 = k[:, :, :, None, :].expand(-1, -1, -1, 4, -1).reshape(128, 1, 32, 128)
v2 = v[:, :, :, None, :].expand(-1, -1, -1, 4, -1).reshape(128, 1, 32, 128)
output_torch2 = attn_torch(q, k2, v2, attention_mask=attention_mask_megatron).transpose(1, 0).cpu().float().numpy()


# S B H D -> B H S D
q = q.permute(1, 2, 0, 3)
k = k.permute(1, 2, 0, 3)
v = v.permute(1, 2, 0, 3)
k = k[:, :, None, :, :].expand(-1, -1, 4, -1, -1).reshape(1, 32, 128, 128)
v = v[:, :, None, :, :].expand(-1, -1, 4, -1, -1).reshape(1, 32, 128, 128)
from torch.nn.attention import SDPBackend, sdpa_kernel
# Only enable flash attention backend
with sdpa_kernel([SDPBackend.MATH]):
    output_torch = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=~attention_mask_megatron, dropout_p=0.0).permute(1, 0, 2, 3).reshape(1, 128, -1).cpu().float().numpy()
# results = torch.empty(
#     q.size(0)*q.size(1),
#     q.size(2),
#     k.size(2),
#     dtype=q.dtype,
#     device=q.device,
# )
# attn_weights = torch.baddbmm(results, q.reshape(q.size(0)*q.size(1), q.size(2), -1), k.transpose(2, 3).reshape(k.size(0)*k.size(1), k.size(2), -1), beta=0.0, alpha=1.0 / math.sqrt(128))
# attn_weights = attn_weights.view(q.size(0), q.size(1), q.size(2), k.size(2))
# attn_weights = attn_weights.masked_fill(attention_mask_megatron, -10000)


# # upcast attention to fp32
# attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
# output_torch = torch.matmul(attn_weights, v)
# output_torch = output_torch.permute(0, 2, 1, 3).reshape(1, 128, -1).cpu().float().numpy()
inspect_output(output_te, output_torch)
inspect_output(output_te, output_torch2)
inspect_output(output_torch, output_torch2)


hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.214874267578125, mean_diff: 0.00832381658256054
min_r_diff: 0.0, max_r_diff: 79489.46875, mean_r_diff: 7.876452922821045
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0009765625, mean_diff: 7.922864824649878e-06
min_r_diff: 0.0, max_r_diff: 91.75723266601562, mean_r_diff: 0.009910470806062222
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.214874267578125, mean_diff: 0.008323675021529198
min_r_diff: 0.0, max_r_diff: 80899.7109375, mean_r_diff: 8.07060718536377


In [66]:

# case I, float32 input

input = np.load(f"{cache_dir}/hf_model.dropout.npy")
input = torch.from_numpy(input).cuda().bfloat16()

weight = np.load(f"{cache_dir}/megatron_rmsnorm_weight.npy")
weight = torch.from_numpy(weight).cuda().bfloat16()

hf_rms_norm = Emu3RMSNorm(4096, eps=1e-5)
hf_rms_norm.weight.requires_grad = False

te_rms_norm = te.RMSNorm(4096, eps=1e-5)
te_rms_norm.weight.requires_grad = False

torch_rms_norm = torch.nn.RMSNorm(4096, eps=1e-5)
torch_rms_norm.weight.requires_grad = False

torch_rms_norm_2 = TorchRMSNorm(4096, zero_centered_gamma=False, eps=1e-5)
torch_rms_norm_2.weight.requires_grad = False

hf_rms_norm.weight = torch.nn.Parameter(weight)
te_rms_norm.weight = torch.nn.Parameter(weight)
torch_rms_norm.weight = torch.nn.Parameter(weight)
torch_rms_norm_2.weight = torch.nn.Parameter(weight)

with torch.no_grad():
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output_hf = hf_rms_norm(input).cpu().float().numpy()
        output_te = te_rms_norm(input).cpu().float().numpy()
        output_torch = torch_rms_norm(input).cpu().float().numpy()
        output_torch_2 = torch_rms_norm_2(input).cpu().float().numpy()
inspect_output(output_hf, output_te)
inspect_output(output_hf, output_torch)
inspect_output(output_hf, output_torch_2)
inspect_output(output_te, output_torch)
inspect_output(output_te, output_torch_2)

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082487804349512
min_r_diff: 0.0, max_r_diff: 0.007812499068677425, mean_r_diff: 0.0013991205487400293
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.011625289916992188, mean_diff: 0.0002484717406332493
min_r_diff: 0.0, max_r_diff: 0.0076371352188289165, mean_r_diff: 0.0019240325782448053
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082429596688598
min_r_diff: 0.0, max_r_diff: 0.007812499068677425, mean_r_diff: 0.0013990971492603421
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0077877044677734375, mean_diff: 0.00018516556883696467
min_r_diff: 0.0, max_r_diff: 0.004377623554319143, mean_r_diff: 0.0014343515504151583
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 

  self.activation_dtype = torch.get_autocast_gpu_dtype()


In [67]:
def fp32_rms_norm(x, weight):
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    weight = weight.to(torch.float32)
    x = x * torch.rsqrt(variance + 1e-5)
    return (weight * x).bfloat16()
output_fp32 = fp32_rms_norm(input, weight).cpu().float().numpy()
inspect_output(output_fp32, output_te)
inspect_output(output_fp32, output_hf)
inspect_output(output_fp32, output_torch)

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.000244140625, mean_diff: 5.820766091346741e-10
min_r_diff: 0.0, max_r_diff: 0.006134953815490007, mean_r_diff: 2.340290450320026e-08
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082429596688598
min_r_diff: 0.0, max_r_diff: 0.007812499068677425, mean_r_diff: 0.0013987725833430886
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0077877044677734375, mean_diff: 0.0001851654815254733
min_r_diff: 0.0, max_r_diff: 0.004377623554319143, mean_r_diff: 0.0014343486400321126
