# Verifying the conversion of a Megatron checkpoint to a Hugging Face Transformers checkpoint

In [1]:
from transformers import AutoModelForCausalLM
import torch
import os
import sys
import warnings
# Silence the warnings emitted while torch.load deserializes the Megatron checkpoints.
warnings.filterwarnings('ignore')
# Megatron itself does not have to be installed, but its package must be importable
# so that the checkpoints can be deserialized.
# (This notebook was run on a Mac.)

# sys.path.append("PATH TO MEGATRON")  # left commented out because the notebook is run from that directory

# Hugging Face reference model that every converted tensor is compared against.
reference_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    device_map="cpu",
    torch_dtype="bfloat16",
    low_cpu_mem_usage=True,
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
# Root folder of the Megatron checkpoint iteration to verify.
# To get this you need to convert the Llama model to Megatron format first.
BASE_FOLDER = "../megatron-llama3-8B-PP4-TP8-mcore/iter_0000001"
# Shard path relative to BASE_FOLDER, keyed by tensor- and pipeline-parallel rank.
FILE_FORMAT = "mp_rank_{tensor:02d}_{pipeline:03d}/model_optim_rng.pt"

# Fail early and name the missing path: BASE_FOLDER is used as a directory prefix for
# every shard, so it has to be an existing directory (not just any existing path).
assert os.path.isdir(BASE_FOLDER), (
    f"Megatron checkpoint folder not found: {os.path.abspath(BASE_FOLDER)}"
)

## Embedding layer

In [3]:
# The word embeddings usually reside in the first pipeline stage (pipeline rank 0).
# Each of the 8 tensor-parallel ranks holds a slice of rows, concatenated along dim 0.
embedding_shards = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 0)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['embedding.word_embeddings.weight']
    if i == 0:
        print(to_add.shape)
    embedding_shards.append(to_add)

embedding = torch.cat(embedding_shards, dim=0)
print(f"{embedding.shape=}")

torch.Size([16128, 4096])
embedding.shape=torch.Size([129024, 4096])


In [4]:
# Megatron creates additional embedding rows beyond the Hugging Face vocabulary.

# x = number of rows in the Hugging Face embedding, y = embedding width.
x, y = reference_model.model.embed_tokens.weight.data.shape

# Rows from index x onwards are the extra ones. They are compared against the first
# extra row in a single vectorized operation. Slicing with [:1] instead of indexing
# embedding[x] keeps this cell re-runnable: once `embedding` has been trimmed (below)
# there are no extra rows left and the check is vacuously True rather than an IndexError.
padding_rows = embedding[x:, :]
print("All duplicates after x: ", end ="")
print(bool((padding_rows == padding_rows[:1]).all()))

# Drop the extra rows so the table can be compared with the reference one-to-one.
embedding = embedding[:x,:]
print("Is equal to huggingface model: ", end="")
print(torch.equal(embedding, reference_model.model.embed_tokens.weight.data))

All duplicates after x: True
Is equal to huggingface model: True


## Attention

In [5]:
# Model dimensions, read once from the Hugging Face config.
_config = reference_model.config
hidden_size, num_heads, n_layers = (
    _config.hidden_size,
    _config.num_attention_heads,
    _config.num_hidden_layers,
)
gqa_head = _config.num_key_value_heads  # number of key/value heads
ffn_size = _config.intermediate_size
dim = hidden_size // num_heads  # size of a single attention head

In [6]:
def fused_to_qkv(fused, nh, ng, dim):
    """
    Split a fused QKV weight into separate q, k and v weights.

    The rows of ``fused`` are expected to be laid out group by group: for each
    of the ``ng`` key/value groups, first the ``nh // ng`` query heads of that
    group, then the group's key head, then its value head.

    Args:
        fused: 2-D weight of shape [(nh + 2 * ng) * dim, in_features]
        nh: number of heads
        ng: number of groups
        dim: kv channels (rows per head)

    Returns:
        q, k, v with shapes [nh * dim, in_features], [ng * dim, in_features]
        and [ng * dim, in_features]

    Raises:
        ValueError: if ``nh`` is not a multiple of ``ng``
    """
    if nh % ng != 0:
        # Otherwise the integer division below would silently drop query rows.
        raise ValueError(f"nh ({nh}) must be divisible by ng ({ng})")

    # Use the input's own column count for the final reshape. Assuming it equals
    # dim * nh only holds when in_features happens to match nh * dim.
    in_features = fused.shape[-1]
    q_rows = dim * nh // ng  # query rows contributed by one group

    reshaped = fused.reshape(ng, q_rows + 2 * dim, -1)
    q, k, v = torch.split(reshaped, [q_rows, dim, dim], dim=1)
    return (
        q.reshape(-1, in_features),
        k.reshape(-1, in_features),
        v.reshape(-1, in_features),
    )

def fused_mlp_to_gate_up(fused_tensor,ffn_hidden_size):
    """
    Split a fused tensor into its gate and up projections.

    The first ``ffn_hidden_size`` rows (dim 0) form the gate projection and the
    next ``ffn_hidden_size`` rows form the up projection.

    Returns:
        gate, up_proj
    """
    chunk_sizes = [ffn_hidden_size] * 2
    gate_rows, up_rows = fused_tensor.split(chunk_sizes, dim=0)
    return gate_rows, up_rows


### QKV

In [8]:
# Load the fused QKV weight of decoder layer 0 from each of the 8 tensor-parallel
# ranks (pipeline rank 0) and stack the shards along dim 0.
attn_shards = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 0)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['decoder.layers.0.self_attention.linear_qkv.weight']
    print(to_add.shape)
    attn_shards.append(to_add)

attn = torch.cat(attn_shards, dim=0)

torch.Size([768, 4096])
torch.Size([768, 4096])
torch.Size([768, 4096])
torch.Size([768, 4096])
torch.Size([768, 4096])
torch.Size([768, 4096])
torch.Size([768, 4096])
torch.Size([768, 4096])


In [9]:
# Split the fused weight and compare each part with the matching Hugging Face
# projection of layer 0 (printed in q, k, v order).
q, k, v = fused_to_qkv(attn, num_heads, gqa_head, dim)
layer0_attn = reference_model.model.layers[0].self_attn
for converted, reference in (
    (q, layer0_attn.q_proj),
    (k, layer0_attn.k_proj),
    (v, layer0_attn.v_proj),
):
    print(torch.equal(converted, reference.weight.data))

True
True
True


### O proj

In [11]:
# linear_proj -> o_proj
# Each tensor-parallel rank holds a slice of columns, so the shards are joined on dim 1.
o_proj_shards = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 0)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['decoder.layers.0.self_attention.linear_proj.weight']
    o_proj_shards.append(to_add)

o_proj = torch.cat(o_proj_shards, dim=1)
print(o_proj.shape)

torch.Size([4096, 4096])


In [12]:
# The merged projection should match the Hugging Face o_proj weight of layer 0 exactly.
hf_o_proj_weight = reference_model.model.layers[0].self_attn.o_proj.weight
torch.equal(o_proj, hf_o_proj_weight)

True

### MLP

#### FC1 -> up_proj, gate_proj

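Megatron saves FC1 with the gate and up projections fused per tensor-parallel rank: gate_tp1, up_tp1, gate_tp2, up_tp2, ..., gate_tpN, up_tpN. Each rank's shard therefore has to be split into its gate and up halves before the halves are concatenated across ranks.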

In [19]:

# Every tensor-parallel rank stores its slice of gate and up fused in one tensor,
# so each shard is split first and the halves are collected separately.
gate_shards = []
up_shards = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 0)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    fused = torch.load(path, map_location="cpu", weights_only=False)['model']['decoder.layers.0.mlp.linear_fc1.weight']
    # Each of the 8 ranks holds 1/8 of the FFN rows for gate and for up.
    _gate_proj, _up_proj = fused_mlp_to_gate_up(fused, ffn_size // 8)
    gate_shards.append(_gate_proj)
    up_shards.append(_up_proj)

gate_proj = torch.cat(gate_shards, dim=0)
up_proj = torch.cat(up_shards, dim=0)
print(f"{gate_proj.shape=}")
print(f"{up_proj.shape=}")

gate_proj.shape=torch.Size([14336, 4096])
up_proj.shape=torch.Size([14336, 4096])


In [15]:
# Compare the reassembled gate and up projections with layer 0 of the reference model
# (printed in gate, up order).
layer0_mlp = reference_model.model.layers[0].mlp
for converted, reference in (
    (gate_proj, layer0_mlp.gate_proj),
    (up_proj, layer0_mlp.up_proj),
):
    print(torch.equal(converted, reference.weight.data))

True
True


#### FC2 -> down_proj

In [20]:
# FC2 shards are slices of columns, so they are joined along dim 1.
down_proj_shards = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 0)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['decoder.layers.0.mlp.linear_fc2.weight']
    down_proj_shards.append(to_add)

down_proj = torch.cat(down_proj_shards, dim=1)
print(f"{down_proj.shape=}")

down_proj.shape=torch.Size([4096, 14336])


In [21]:
# The merged FC2 weight should match the Hugging Face down_proj weight of layer 0.
hf_down_proj = reference_model.model.layers[0].mlp.down_proj.weight.data
matches_reference = torch.equal(down_proj, hf_down_proj)
print(matches_reference)

True


### Norms

This was a little confusing: the QKV norm (`self_attention.linear_qkv.layer_norm_weight`) is the layer norm applied before the QKV projection, so it corresponds to `input_layernorm`.
qkv_norm -> input_layernorm

Likewise, the FC1 layer norm (`mlp.linear_fc1.layer_norm_weight`) corresponds to `post_attention_layernorm`.

Norms are duplicated across all tensor-parallel ranks.

In [23]:
# Collect the layer-0 input layer norm weight from each of the 8 tensor-parallel ranks.
input_layernorm = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 0)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['decoder.layers.0.self_attention.linear_qkv.layer_norm_weight']
    print(to_add.shape)
    input_layernorm.append(to_add)

torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])


In [27]:
# Check that every rank holds the same copy of the weight.
# Comparing each copy with the first one is enough: if x=y and x=z then y=z.
first_copy = input_layernorm[0]
print("All duplicate weights: ", end = '')
print(all(torch.equal(first_copy, other) for other in input_layernorm))

# One copy is then enough for the comparison with the Hugging Face weight.
print("\nReference to HF: ", end = '')
print(torch.equal(first_copy, reference_model.model.layers[0].input_layernorm.weight))

All duplicate weights: True

Reference to HF: True


In [28]:
# Collect the layer-0 FC1 layer norm weight from each of the 8 tensor-parallel ranks.
post_attention_layernorm = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 0)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['decoder.layers.0.mlp.linear_fc1.layer_norm_weight']
    print(to_add.shape)
    post_attention_layernorm.append(to_add)

torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])


In [29]:
# Check that every rank holds the same copy of the weight.
first_copy = post_attention_layernorm[0]
print("All duplicate weights: ", end = '')
print(all(torch.equal(first_copy, other) for other in post_attention_layernorm))

# One copy is then enough for the comparison with the Hugging Face weight.
print("\nReference to HF: ", end = '')
print(torch.equal(first_copy, reference_model.model.layers[0].post_attention_layernorm.weight))

All duplicate weights: True

Reference to HF: True


## Final Layernorm

In [250]:
# Collect the final layer norm weight from each of the 8 tensor-parallel ranks.
final_layernorm = []
for i in range(8):
    # Note: the final layer norm is read from the last pipeline stage (pipeline rank 3).
    # FILE_FORMAT uses the named fields {tensor} and {pipeline}, so the ranks must be
    # passed as keyword arguments; positional arguments raise KeyError: 'tensor'.
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 3)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['decoder.final_layernorm.weight']
    print(to_add.shape)
    final_layernorm.append(to_add)

  to_add = torch.load(f"{BASE_FOLDER}/{FILE_FORMAT.format(i, 3)}")['model']['decoder.final_layernorm.weight']


torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])
torch.Size([4096])


In [254]:
# Check that every rank holds the same copy of the weight.
first_copy = final_layernorm[0]
print("All duplicate weights: ", end = '')
print(all(torch.equal(first_copy, other) for other in final_layernorm))

# One copy is then enough for the comparison with the Hugging Face final norm.
print("\nReference to HF: ", end = '')
print(torch.equal(first_copy, reference_model.model.norm.weight))


All duplicate weights: True

Reference to HF: True


## Output layer

In [32]:
# Note: the output layer is read from the last pipeline stage (pipeline rank 3).
# Each of the 8 tensor-parallel ranks holds a slice of rows, concatenated along dim 0.
output_layer_shards = []
for i in range(8):
    path = f"{BASE_FOLDER}/{FILE_FORMAT.format(tensor = i, pipeline = 3)}"
    # weights_only=False: the checkpoint needs Megatron classes to unpickle, which the
    # restricted unpickler (the torch.load default from PyTorch 2.6 on) rejects.
    # This executes pickled code, so only load checkpoints you trust.
    # map_location="cpu": the reference model lives on the CPU.
    to_add = torch.load(path, map_location="cpu", weights_only=False)['model']['output_layer.weight']
    print(to_add.shape)
    output_layer_shards.append(to_add)

output_layer = torch.cat(output_layer_shards, dim=0)
print(f"{output_layer.shape=}")

torch.Size([16128, 4096])
torch.Size([16128, 4096])
torch.Size([16128, 4096])
torch.Size([16128, 4096])
torch.Size([16128, 4096])
torch.Size([16128, 4096])
torch.Size([16128, 4096])
torch.Size([16128, 4096])
output_layer.shape=torch.Size([129024, 4096])
