## Validate Tokenizer

In [2]:
from megatron.training.tokenizer.tokenizer import _HuggingFaceTokenizer
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Load the same Emu3-Gen checkpoint through both tokenizer front-ends so their
# outputs can be compared below.
emu3_model_path = "/data/models/Emu3-Gen"
tokenizer_megatron = _HuggingFaceTokenizer(emu3_model_path)
tokenizer_huggingface = AutoTokenizer.from_pretrained(
    emu3_model_path,
    trust_remote_code=True,
)

  vision_tokens = [t.strip() for t in open(special_tokens_file).readlines() if len(t.strip()) > 0]


In [4]:
# Display the `eof_token` attribute exposed by the HuggingFace tokenizer
# (presumably the end-of-file special token — confirm against the tokenizer code).
tokenizer_huggingface.eof_token

'<|extra_201|>'

In [5]:
# The Megatron wrapper tokenizes the whole batch at once; the HuggingFace
# tokenizer is called per sentence. Both must yield the same token ids.
word_list = ['a portrait of young girl.', 'a portrait of young man.']
megatron_token_ids = tokenizer_megatron.tokenize(word_list)
huggingface_token_ids = [tokenizer_huggingface.encode(word) for word in word_list]
assert megatron_token_ids == huggingface_token_ids


## Validate Model Forward and Backward

In [6]:
import numpy as np
import torch

In [7]:
# Comparison tolerances keyed by dtype name; `atol` / `rtol` are the values
# selected by `Dtype` and passed to every np.allclose call in this notebook.
Dtype = 'bfloat16'
ATOL = dict(float32=1e-3, bfloat16=5e-2, float16=5e-2)
RTOL = dict(float32=1.3e-6, bfloat16=1e-2, float16=1e-2)
atol, rtol = ATOL[Dtype], RTOL[Dtype]

In [8]:
cache_dir = "/root/Megatron-LM/cache"

def inspect_output(hf_array, megatron_array):
    """Print the shapes of two arrays and statistics of their absolute difference.

    Args:
        hf_array: First array (labelled ``hf_array`` in the printed output).
        megatron_array: Second array (labelled ``megatron_array``).

    The difference is computed with NumPy broadcasting, so the two shapes only
    need to be broadcast-compatible. Nothing is returned; the minimum, maximum
    and mean of the element-wise absolute difference are printed.
    """
    shape_report = f"hf_array.shape: {hf_array.shape}, megatron_array.shape: {megatron_array.shape}"
    print(shape_report)
    abs_delta = np.abs(hf_array - megatron_array)
    min_diff, max_diff, mean_diff = abs_delta.min(), abs_delta.max(), abs_delta.mean()
    print(f"min_diff: {min_diff}, max_diff: {max_diff}, mean_diff: {mean_diff}")

In [9]:
# Word-embedding outputs: both dumps are compared as stored (no transpose).
# Paths go through `cache_dir` so the dump location is configured in one place.
embedding_hf = np.load(f"{cache_dir}/hf_model.embed_tokens.npy")
embedding_megatron = np.load(f"{cache_dir}/megatron_embedding.word_embeddings.npy")
embedding_flag = np.allclose(embedding_hf, embedding_megatron, atol=atol, rtol=rtol)
print("embedding: ", embedding_flag)
inspect_output(embedding_hf, embedding_megatron)
# embedding dropout
embedding_dropout_hf = np.load(f"{cache_dir}/hf_model.dropout.npy")
embedding_dropout_megatron = np.load(f"{cache_dir}/megatron_embedding.embedding_dropout.npy")
# Swap the first two axes of the Megatron dump once so it lines up with the HF layout.
embedding_dropout_megatron_t = embedding_dropout_megatron.transpose(1, 0, 2)
dropout_flag = np.allclose(embedding_dropout_hf, embedding_dropout_megatron_t, atol=atol, rtol=rtol)
print("embedding dropout: ", dropout_flag)
inspect_output(embedding_dropout_hf, embedding_dropout_megatron_t)

embedding:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0
embedding dropout:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0


In [10]:
# qkv_proj
# Megatron emits one fused QKV tensor. After swapping the first two axes it is
# viewed as 8 groups per token; within each group the first 512 features are
# taken as Q, the next 128 as K and the remainder as V.
qkv_megatron = np.load(f"{cache_dir}/megatron_decoder.layers.0.self_attention.linear_qkv_o0.npy")
qkv_megatron_grouped = qkv_megatron.transpose(1, 0, 2).reshape(1, 128, 8, -1)
q_megatron = qkv_megatron_grouped[:, :, :, :512].reshape(1, 128, -1)
k_megatron = qkv_megatron_grouped[:, :, :, 512:640].reshape(1, 128, -1)
v_megatron = qkv_megatron_grouped[:, :, :, 640:].reshape(1, 128, -1)
q_hf = np.load(f"{cache_dir}/hf_model.layers.0.self_attn.q_proj.npy")
k_hf = np.load(f"{cache_dir}/hf_model.layers.0.self_attn.k_proj.npy")
v_hf = np.load(f"{cache_dir}/hf_model.layers.0.self_attn.v_proj.npy")
for proj_name, hf_proj, megatron_proj in (
    ("q", q_hf, q_megatron),
    ("k", k_hf, k_megatron),
    ("v", v_hf, v_megatron),
):
    # np.allclose scales rtol by its second argument, so HF stays the reference.
    proj_flag = np.allclose(megatron_proj, hf_proj, atol=atol, rtol=rtol)
    print(f"linear {proj_name}: ", proj_flag)
    # Pass HF first so the printed hf_array/megatron_array labels are truthful.
    inspect_output(hf_proj, megatron_proj)

linear q:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.00031974565354175866
linear k:  True
hf_array.shape: (1, 128, 1024), megatron_array.shape: (1, 128, 1024)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.0004661270650103688
linear v:  True
hf_array.shape: (1, 128, 1024), megatron_array.shape: (1, 128, 1024)
min_diff: 0.0, max_diff: 0.001953125, mean_diff: 7.871985144447535e-05


In [11]:
# attention output
# Walk over decoder layers 0..31 and compare the dumped self-attention outputs.
# The Megatron dump has its first two axes swapped to line up with the HF dump.
for layer_idx in range(32):
    hf_dump = f"{cache_dir}/hf_model.layers.{layer_idx}.self_attn_o0.npy"
    megatron_dump = f"{cache_dir}/megatron_decoder.layers.{layer_idx}.self_attention_o0.npy"
    attention_output_hf = np.load(hf_dump)
    attention_output_megatron = np.load(megatron_dump).transpose(1, 0, 2)
    layer_matches = np.allclose(attention_output_hf, attention_output_megatron, atol=atol, rtol=rtol)
    print(f"layer {layer_idx}: ", layer_matches)
    inspect_output(attention_output_hf, attention_output_megatron)

layer 0:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0009765625, mean_diff: 5.376178251026431e-06
layer 1:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00011902677942998707
layer 2:  True
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.0008853924227878451
layer 3:  False
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.228515625, mean_diff: 0.002636928576976061
layer 4:  False
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.1171875, mean_diff: 0.0017968036700040102
layer 5:  False
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.09375, mean_diff: 0.001739551080390811
layer 6:  False
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4

In [12]:
# Final logits: compare the HF lm_head dump with the Megatron output-layer dump
# (first two axes of the Megatron dump swapped to line up with HF).
output_hf = np.load(f"{cache_dir}/hf_lm_head.npy")
output_megatron = np.load(f"{cache_dir}/megatron_output_layer_o0.npy").transpose(1,0,2)
inspect_output(output_hf, output_megatron)
# A bare expression in the middle of a cell is never displayed, so print the verdict.
print("lm_head allclose: ", np.allclose(output_hf, output_megatron, atol=atol, rtol=rtol))

# Count positions whose arg-max token differs between the two implementations.
output_hf_label = output_hf.argmax(axis=-1)
output_megatron_label = output_megatron.argmax(axis=-1)
print(np.sum(output_hf_label!=output_megatron_label))

hf_array.shape: (1, 128, 184622), megatron_array.shape: (1, 128, 184622)


min_diff: 0.0, max_diff: 2.59375, mean_diff: 0.09347893297672272
11


In [13]:
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension ``[a, b]`` (split at ``size // 2``) the result is
    ``[-b, a]``.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((-back, front), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Rotate query and key tensors with Rotary Position Embedding.

    Args:
        q (`torch.Tensor`): Query tensor.
        k (`torch.Tensor`): Key tensor.
        cos (`torch.Tensor`): Cosine table, indexed by ``position_ids``.
        sin (`torch.Tensor`): Sine table, indexed by ``position_ids``.
        position_ids (`torch.Tensor`): Indices used to gather rows of ``cos``
            and ``sin`` for each token (e.g. offset positions with a KV-cache).
        unsqueeze_dim (`int`, *optional*, defaults to 1): Axis inserted into the
            gathered ``cos``/``sin`` so they broadcast against ``q`` and ``k``.
            Use 1 when q/k are laid out as [batch, heads, seq, head_dim] and 2
            when they are [batch, seq, heads, head_dim].

    Returns:
        `tuple(torch.Tensor)`: the rotated query and the rotated key.
    """
    cos_at_pos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin_at_pos = sin[position_ids].unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Standard RoPE combination: x * cos + rotate_half(x) * sin.
        return (states * cos_at_pos) + (rotate_half(states) * sin_at_pos)

    return _rotate(q), _rotate(k)

In [14]:
# rotary embedding
# HF dumps cos/sin directly; the Megatron dump is passed through cos/sin here.
# The Megatron file is read once and reused instead of being loaded twice.
cos_hf = np.load(f"{cache_dir}/hf_model.layers.1.self_attn.rotary_emb_o0.npy")
sin_hf = np.load(f"{cache_dir}/hf_model.layers.1.self_attn.rotary_emb_o1.npy")
rotary_pos_emb_megatron = np.load(f"{cache_dir}/megatron_rotary_pos_emb.npy")
cos_megatron = np.cos(rotary_pos_emb_megatron).reshape(128, 128)
sin_megatron = np.sin(rotary_pos_emb_megatron).reshape(128, 128)
inspect_output(cos_hf, cos_megatron)
inspect_output(sin_hf, sin_megatron)

hf_array.shape: (128, 128), megatron_array.shape: (128, 128)
min_diff: 0.0, max_diff: 0.0019521117210388184, mean_diff: 0.0004719264688901603
hf_array.shape: (128, 128), megatron_array.shape: (128, 128)
min_diff: 0.0, max_diff: 0.001952826976776123, mean_diff: 0.0003145454975310713


In [15]:
# Keep the raw Megatron rotary dump around for ad-hoc inspection.
# NOTE(review): `freq_megatron` is not referenced by any other visible cell.
freq_megatron = np.load(f'{cache_dir}/megatron_rotary_pos_emb.npy')

In [16]:
# Reference attention built from the HF q/k/v projection arrays:
# view them as (1, heads, 128, head_dim) with 32 query heads and 8 key/value
# heads, apply RoPE, repeat every KV head 4x to match the query heads, then run
# causal scaled-dot-product attention.
q = torch.from_numpy(q_hf).cuda().reshape(1, 128, 32, -1).transpose(1, 2)
k = torch.from_numpy(k_hf).cuda().reshape(1, 128, 8, -1).transpose(1, 2)
v = torch.from_numpy(v_hf).cuda().reshape(1, 128, 8, -1).transpose(1, 2)
cos = torch.from_numpy(np.load(f"{cache_dir}/hf_model.layers.0.self_attn.rotary_emb_o0.npy"))
sin = torch.from_numpy(np.load(f"{cache_dir}/hf_model.layers.0.self_attn.rotary_emb_o1.npy"))
rope_position_ids = torch.arange(128, device="cuda").unsqueeze(0)
q, k = apply_rotary_pos_emb(q, k, cos.cuda(), sin.cuda(), position_ids=rope_position_ids)
# Each KV head is repeated 4 times back-to-back along the head axis (8 -> 32).
k = torch.repeat_interleave(k, 4, dim=1)
v = torch.repeat_interleave(v, 4, dim=1)

sdpa_result = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
output = sdpa_result.permute(0, 2, 1, 3).reshape(1, 128, -1).cpu().numpy()

In [17]:
# NOTE(review): this re-inspects `output_hf` / `output_megatron` (the recorded
# output shows the lm_head shape), not the `output` array produced by the
# preceding SDPA cell — confirm which comparison was intended.
inspect_output(output_hf, output_megatron)

hf_array.shape: (1, 128, 184622), megatron_array.shape: (1, 128, 184622)
min_diff: 0.0, max_diff: 2.59375, mean_diff: 0.09347893297672272


In [18]:
# Compare the dumped RMSNorm outputs of both implementations.
# The recorded output shows shapes (1, 128, 4096) vs (128, 4096): the
# subtraction inside inspect_output relies on NumPy broadcasting here.
rms_norm_output_megatron = np.load(f"{cache_dir}/megatron_rms_norm.npy")
rms_norm_output_hf = np.load(f"{cache_dir}/hf_rms_norm.npy")
inspect_output(rms_norm_output_hf, rms_norm_output_megatron)

hf_array.shape: (1, 128, 4096), megatron_array.shape: (128, 4096)
min_diff: 0.0, max_diff: 0.000244140625, mean_diff: 5.820766091346741e-10


## Validate Transformer Engine (TE)

In [19]:
import torch
import torch.nn as nn
import transformer_engine.pytorch as te
import numpy as np
import os

# Disable TF32 and pin the cuBLAS workspace configuration for this process.
# The key must be exactly "CUBLAS_WORKSPACE_CONFIG": a trailing space creates a
# different, unused environment variable.
# NOTE(review): these are set after `import torch`; they presumably need to be
# in place before CUDA/cuBLAS is first initialised in this kernel — confirm.
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"

# torch.backends.cudnn.allow_tf32 = True
# torch.backends.cuda.matmul.allow_tf32 = True


### rmsnorm

In [34]:


class Emu3RMSNorm(nn.Module):
    """RMS normalisation with a learnable per-feature gain (T5LayerNorm-style).

    The statistics are computed in float32; the normalised activations are cast
    back to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        """Create the gain vector (initialised to ones) and store ``eps``."""
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last dimension and apply the gain."""
        original_dtype = hidden_states.dtype
        states_fp32 = hidden_states.to(torch.float32)
        mean_square = states_fp32.pow(2).mean(-1, keepdim=True)
        normalised = states_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

class TorchRMSNorm(nn.Module):
    """Plain-PyTorch RMSNorm reference with optional zero-centered gain.

    Args:
        in_features: Size of the last (normalised) dimension.
        zero_centered_gamma: When True the stored weight is an offset from one,
            i.e. the effective gain is ``1 + weight``.
        eps: Constant added to the mean square before the inverse square root.
    """

    def __init__(self, in_features, zero_centered_gamma, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.in_features = in_features
        self.zero_centered_gamma = zero_centered_gamma

        # The effective gain must start at one in both modes: a zero-centered
        # weight therefore starts at zero, a regular weight starts at one.
        initial_value = torch.zeros(in_features) if zero_centered_gamma else torch.ones(in_features)
        # Assigning an nn.Parameter attribute already registers it on the module.
        self.weight = nn.Parameter(initial_value)

    def forward(self, x):
        """Normalise ``x`` over its last dimension in float32 and cast back to ``x.dtype``."""
        norm_x2 = torch.sum(x.float() ** 2, dim=-1, keepdim=True)
        d_x = self.in_features

        # mean(x^2) + eps, then its inverse square root.
        rms_x2 = norm_x2 / d_x + self.eps
        r_rms_x = rms_x2 ** (-1.0 / 2)
        x_normed = x * r_rms_x

        w = self.weight.float()
        if self.zero_centered_gamma:
            w = 1 + w
        return (w * x_normed).to(x.dtype)

In [21]:
# case I: bfloat16 input, no autocast
# (the dump is loaded with NumPy and cast to bfloat16 on the GPU)

input = np.load(f"{cache_dir}/hf_model.dropout.npy")
input = torch.from_numpy(input).cuda().bfloat16()

weight = np.load(f"{cache_dir}/megatron_rmsnorm_weight.npy")
weight = torch.from_numpy(weight).cuda().bfloat16()

hf_rms_norm = Emu3RMSNorm(4096, eps=1e-5)
te_rms_norm = te.RMSNorm(4096, eps=1e-5, params_dtype=torch.bfloat16)
torch_rms_norm = torch.nn.RMSNorm(4096, eps=1e-5)
torch_rms_norm_2 = TorchRMSNorm(4096, zero_centered_gamma=False, eps=1e-5)

# Give every implementation the same Megatron weight. The freeze is applied to
# the replacement parameter itself; setting requires_grad on the old parameter
# has no effect once it is replaced.
for rms_norm_module in (hf_rms_norm, te_rms_norm, torch_rms_norm, torch_rms_norm_2):
    rms_norm_module.weight = torch.nn.Parameter(weight, requires_grad=False)

with torch.no_grad():
    output_hf = hf_rms_norm(input).cpu().float().numpy()
    output_te = te_rms_norm(input).cpu().float().numpy()
    output_torch = torch_rms_norm(input).cpu().float().numpy()
    output_torch_2 = torch_rms_norm_2(input).cpu().float().numpy()
inspect_output(output_hf, output_te)
inspect_output(output_hf, output_torch)
inspect_output(output_hf, output_torch_2)
inspect_output(output_te, output_torch)
inspect_output(output_te, output_torch_2)

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082487804349512
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00021006364841014147
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082429596688598
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.0002690280962269753
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.000244140625, mean_diff: 5.820766091346741e-10


### qkv Linear

In [22]:
# Replace `input` with its RMS-normalised version; the linear-layer cells below
# consume this normalised tensor.
# NOTE(review): this cell is not idempotent — running it twice normalises the
# already-normalised tensor. Re-run the cell that loads `input` first.
with torch.no_grad():
    input = hf_rms_norm(input)

In [35]:

te_weight = np.load(f'{cache_dir}/megatron_qkv_linear_weight.npy')
hf_q_weight = np.load(f'{cache_dir}/hf_q_linear_weight.npy')


# Linear constructors take (in_features, out_features). The fused QKV projection
# maps 4096 -> 6144, so the arguments must be in that order; the loaded weight
# then has the same shape as the one the constructor allocates.
te_linear = te.Linear(4096, 6144, params_dtype=torch.bfloat16, bias=False, device="cuda")
# Freeze the replacement parameter itself (freezing the old one is a no-op).
te_linear.weight = torch.nn.Parameter(torch.from_numpy(te_weight).cuda().bfloat16(), requires_grad=False)

hf_q_linear = torch.nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16)
hf_q_linear.weight = torch.nn.Parameter(torch.from_numpy(hf_q_weight).cuda().bfloat16(), requires_grad=False)

hf_q_linear_fp32 = torch.nn.Linear(4096, 4096, bias=False)
hf_q_linear_fp32.weight = torch.nn.Parameter(torch.from_numpy(hf_q_weight).cuda(), requires_grad=False)

with torch.no_grad():
    # Q occupies the first 512 features of each of the 8 groups of the fused output.
    q_te = te_linear(input).reshape(1, 128, 8, -1)[:, :, :, :512].reshape(1, 128, -1)
    q_hf = hf_q_linear(input)
    q_hf_fp32 = hf_q_linear_fp32(input.float()).bfloat16()
inspect_output(q_te.cpu().float().numpy(), q_hf.cpu().float().numpy())
inspect_output(q_te.cpu().float().numpy(), q_hf_fp32.cpu().float().numpy())
inspect_output(q_hf.cpu().float().numpy(), q_hf_fp32.cpu().float().numpy())

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.00014284173084888607
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.00014285619545262307
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0078125, mean_diff: 3.199484126525931e-07


In [36]:
# Slice the Q rows out of the fused Megatron QKV weight (first 512 rows of each
# of the 8 groups) and compare them with the HF q_proj weight.
fused_qkv_weight = np.load(f'{cache_dir}/megatron_qkv_linear_weight.npy')
te_q_linear_weight = fused_qkv_weight.reshape(8, -1, 4096)[:, :512, :].reshape(4096, 4096)
hf_q_linear_weight = np.load(f'{cache_dir}/hf_q_linear_weight.npy')
inspect_output(te_q_linear_weight, hf_q_linear_weight)

hf_array.shape: (4096, 4096), megatron_array.shape: (4096, 4096)
min_diff: 0.0, max_diff: 0.0, mean_diff: 0.0


In [37]:
# Run the Q projection alone through te.Linear, using the HF weight, with
# reduced-precision bf16 reductions disabled for torch matmuls.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
te_linear_q = te.Linear(4096, 4096, params_dtype=torch.bfloat16, bias=False, device="cuda")
# Freeze the replacement parameter itself (freezing the old one is a no-op).
te_linear_q.weight = torch.nn.Parameter(torch.from_numpy(hf_q_weight).cuda().bfloat16(), requires_grad=False)
with torch.no_grad():
    q_te = te_linear_q(input)
inspect_output(q_te.cpu().float().numpy(), q_hf.cpu().float().numpy())

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.03125, mean_diff: 0.00020105975272599608


In [43]:
# Same fused-QKV check as above, but with a plain torch.nn.Linear.
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
# nn.Linear takes (in_features, out_features): the fused projection is 4096 -> 6144.
torch_linear_qkv = torch.nn.Linear(4096, 6144, bias=False, dtype=torch.bfloat16)
# Freeze the replacement parameter itself (freezing the old one is a no-op).
torch_linear_qkv.weight = torch.nn.Parameter(torch.from_numpy(te_weight).cuda().bfloat16(), requires_grad=False)
with torch.no_grad():
    # Q occupies the first 512 features of each of the 8 groups of the fused output.
    q_torch = torch_linear_qkv(input).reshape(1, 128, 8, -1)[:, :, :, :512].reshape(1, 128, -1)
inspect_output(q_torch.cpu().float().numpy(), q_hf.cpu().float().numpy())

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0078125, mean_diff: 6.887059385007888e-07


### Attention

In [41]:
# Reorder the axes of q/k/v: axis 2 moves to the front, followed by axes 0, 1, 3.
# NOTE(review): this rebinds q/k/v in place, so the cell is not idempotent —
# running it twice permutes the already-permuted tensors. Re-run the cell that
# builds q/k/v before re-running this one.
q = q.permute(2, 0, 1, 3)
k = k.permute(2, 0, 1, 3)
v = v.permute(2, 0, 1, 3)

In [55]:
# Load the three dumped inputs of Megatron's layer-0 core attention
# (files ..._o0 / _o1 / _o2) as CUDA tensors.
core_attention_input_prefix = f"{cache_dir}/megatron_decoder.layers.0.self_attention.core_attention_input"
q_megatron, k_megatron, v_megatron = (
    torch.from_numpy(np.load(f"{core_attention_input_prefix}_o{dump_idx}.npy")).cuda()
    for dump_idx in range(3)
)


In [None]:
# TE causal attention on the permuted q/k/v built from the HF projections.
attn = te.DotProductAttention(
    num_attention_heads=32,
    kv_channels=128,
    attention_dropout=0.0,
    attn_mask_type="causal",
)
te_context = attn(q, k, v)
output_te = te_context.transpose(1, 0).cpu().numpy()

In [66]:
# TE causal attention with 8 GQA groups on the dumped Megatron q/k/v inputs.
attn = te.DotProductAttention(
    num_attention_heads=32,
    kv_channels=128,
    num_gqa_groups=8,
    attention_dropout=0.0,
    attn_mask_type="causal",
)
te_context_megatron = attn(q_megatron, k_megatron, v_megatron)
output_te = te_context_megatron.transpose(1, 0).cpu().numpy()

In [None]:
# Show row [0, 10] of `output`, the torch SDPA reference result.
output[0,10]

In [None]:
# Show row [0, 10] of `output_te`, the Transformer Engine attention result.
output_te[0,10]

In [None]:
# NOTE(review): `core_attn_output_megatron` is not assigned in any visible cell;
# on a fresh kernel this line would raise NameError. Load the Megatron core
# attention output dump before this cell, or remove the cell.
core_attn_output_megatron[0,10]

In [52]:

# case II: bfloat16 input, forward passes run under bf16 autocast
# (the dump is loaded with NumPy and cast to bfloat16 on the GPU)

input = np.load(f"{cache_dir}/hf_model.dropout.npy")
input = torch.from_numpy(input).cuda().bfloat16()

weight = np.load(f"{cache_dir}/megatron_rmsnorm_weight.npy")
weight = torch.from_numpy(weight).cuda().bfloat16()

hf_rms_norm = Emu3RMSNorm(4096, eps=1e-5)
te_rms_norm = te.RMSNorm(4096, eps=1e-5)
torch_rms_norm = torch.nn.RMSNorm(4096, eps=1e-5)
torch_rms_norm_2 = TorchRMSNorm(4096, zero_centered_gamma=False, eps=1e-5)

# Give every implementation the same Megatron weight. The freeze is applied to
# the replacement parameter itself; setting requires_grad on the old parameter
# has no effect once it is replaced.
for rms_norm_module in (hf_rms_norm, te_rms_norm, torch_rms_norm, torch_rms_norm_2):
    rms_norm_module.weight = torch.nn.Parameter(weight, requires_grad=False)

with torch.no_grad():
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        output_hf = hf_rms_norm(input).cpu().float().numpy()
        output_te = te_rms_norm(input).cpu().float().numpy()
        output_torch = torch_rms_norm(input).cpu().float().numpy()
        output_torch_2 = torch_rms_norm_2(input).cpu().float().numpy()
inspect_output(output_hf, output_te)
inspect_output(output_hf, output_torch)
inspect_output(output_hf, output_torch_2)
inspect_output(output_te, output_torch)
inspect_output(output_te, output_torch_2)

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082487804349512
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.011625289916992188, mean_diff: 0.0002484717406332493
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.015625, mean_diff: 0.00018082429596688598
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.0077877044677734375, mean_diff: 0.00018516556883696467
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 0.000244140625, mean_diff: 5.820766091346741e-10


  self.activation_dtype = torch.get_autocast_gpu_dtype()


In [51]:
def fp32_rms_norm(x, weight, eps=1e-5):
    """RMS-normalise ``x`` entirely in float32 and return the result as bfloat16.

    Args:
        x: Input tensor; normalised over its last dimension.
        weight: Gain multiplied onto the normalised values (cast to float32).
        eps: Constant added to the mean square before the inverse square root.
            Defaults to 1e-5, the value that was previously hard-coded.

    Returns:
        ``weight * x / sqrt(mean(x^2) + eps)`` computed in float32, cast to bfloat16.
    """
    x = x.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    weight = weight.to(torch.float32)
    x = x * torch.rsqrt(variance + eps)
    # Only the final product is rounded to bfloat16.
    return (weight * x).bfloat16()
# Compare the all-float32 reference against each previously computed output.
# NOTE(review): the recorded max_diff (~10) is far larger than in every other
# comparison. `input` may already have been replaced by its normalised version
# when this cell ran — confirm with Restart & Run All.
output_fp32 = fp32_rms_norm(input, weight).cpu().float().numpy()
for other_output in (output_te, output_hf, output_torch):
    inspect_output(output_fp32, other_output)

hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 10.078125, mean_diff: 0.0809398666024208
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 10.078125, mean_diff: 0.08094507455825806
hf_array.shape: (1, 128, 4096), megatron_array.shape: (1, 128, 4096)
min_diff: 0.0, max_diff: 10.078125, mean_diff: 0.08098115772008896


In [37]:
# Display the current `output_hf` array for a side-by-side look with `output_te`.
output_hf

array([[[-0.20117188,  0.06689453,  0.02563477, ...,  0.27539062,
         -0.20507812,  0.15234375],
        [ 0.15039062, -0.37109375,  0.34570312, ...,  0.44726562,
         -0.11279297, -0.05859375],
        [ 0.00604248, -0.01586914,  0.09619141, ..., -0.2890625 ,
         -0.01452637,  0.140625  ],
        ...,
        [-0.33007812,  0.01037598,  0.12011719, ..., -0.20605469,
         -0.02709961, -0.00515747],
        [-0.296875  ,  0.10498047,  0.20800781, ..., -0.21484375,
          0.06835938,  0.46289062],
        [ 0.13183594, -0.09814453,  0.18847656, ..., -0.11767578,
         -0.12207031, -0.515625  ]]], dtype=float32)

In [36]:
# Display the current `output_te` array for a side-by-side look with `output_hf`.
output_te

array([[[-0.20214844,  0.06689453,  0.02563477, ...,  0.27539062,
         -0.20507812,  0.15234375],
        [ 0.15039062, -0.37109375,  0.34570312, ...,  0.44726562,
         -0.11279297, -0.05859375],
        [ 0.006073  , -0.01586914,  0.09619141, ..., -0.28710938,
         -0.01452637,  0.14160156],
        ...,
        [-0.33007812,  0.01031494,  0.12011719, ..., -0.20605469,
         -0.02722168, -0.00515747],
        [-0.296875  ,  0.10498047,  0.20703125, ..., -0.21484375,
          0.06884766,  0.46289062],
        [ 0.13183594, -0.09765625,  0.18847656, ..., -0.11767578,
         -0.12207031, -0.51953125]]], dtype=float32)