In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
import sys
sys.path.append("/iopsstor/scratch/cscs/xyixuan/Megatron-LM")
sys.path.append("/capstor/users/cscs/xyixuan/PDM")

In [5]:
from transformers import AutoModelForCausalLM, AutoConfig
import torch

# Load the checkpoint
checkpoint_path='/iopsstor/scratch/cscs/xyixuan/Megatron-LM/logs/Meg-Runs/Goldfish/llama3-1b-15n-8192sl-60gbsz-goldfish-no-bos/torch/iter_0170005/mp_rank_00/model_optim_rng.pt'
checkpoint = torch.load(checkpoint_path, weights_only=False, map_location='cpu')

In [10]:
args = checkpoint['args']

config = AutoConfig.for_model(
    architectures=["LlamaForCausalLM"],
    attention_bias=False,
    attention_dropout=args.attention_dropout,
    bos_token_id=128000,
    eos_token_id=128001,
    head_dim=int(args.hidden_size/args.num_attention_heads),
    hidden_act="silu",
    hidden_size=args.hidden_size,
    initializer_range=0.01,
    intermediate_size=args.ffn_hidden_size,
    max_position_embeddings=131072,
    mlp_bias=False,
    model_type="llama",
    num_attention_heads=args.num_attention_heads,
    num_hidden_layers=args.num_layers,
    num_key_value_heads=args.num_query_groups,
    pretraining_tp=1,
    rms_norm_eps=args.norm_epsilon,
    rope_scaling={
        "factor": args.rope_scaling_factor,
        "high_freq_factor": 4.0,
        "low_freq_factor": 1.0,
        "original_max_position_embeddings": args.max_position_embeddings,
        "rope_type": "llama3"
    },
    rope_theta=args.rotary_base,
    tie_word_embeddings=not args.untie_embeddings_and_output_weights,
    torch_dtype=args.params_dtype,
    use_cache=True,
    vocab_size=args.padded_vocab_size
)

In [11]:
def convert_megatron_to_hf_dict(model_dict):
    """
    Implementation adopted from https://github.com/TJ-Solergibert/NeMo/blob/825c246b12e76ee7e9b3cdf01aea9c9dacdc03fe/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py#L106
    ALERT: Currently no support for model parallelism
    """
    from collections import OrderedDict 
    checkpoint = OrderedDict()  # using OrderedDict for consistent ordering
    
    # Get model dimensions
    hidden_size = model_dict['decoder.layers.0.self_attention.linear_qkv.weight'].shape[1]
    head_num = args.num_attention_heads
    num_query_groups = args.num_query_groups
    
    # Calculate attention dimensions
    head_size = hidden_size // head_num
    heads_per_group = head_num // num_query_groups
    qkv_total_dim = head_num + 2 * num_query_groups

    # Save embedding
    checkpoint['model.embed_tokens.weight'] = model_dict['embedding.word_embeddings.weight']

    # Process each transformer layer
    for layer_idx in range(args.num_layers):
        # Handle QKV weights
        qkv_weights = model_dict[f'decoder.layers.{layer_idx}.self_attention.linear_qkv.weight']
        qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size])

        # Calculate indices for Q, K, V separation
        q_slice = torch.cat([
            torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
            for i in range(num_query_groups)
        ])
        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))

        # Separate and save Q, K, V weights
        checkpoint[f'model.layers.{layer_idx}.self_attn.q_proj.weight'] = qkv_weights[q_slice].reshape(-1, hidden_size)
        checkpoint[f'model.layers.{layer_idx}.self_attn.k_proj.weight'] = qkv_weights[k_slice].reshape(-1, hidden_size)
        checkpoint[f'model.layers.{layer_idx}.self_attn.v_proj.weight'] = qkv_weights[v_slice].reshape(-1, hidden_size)

        # Save attention output projection
        checkpoint[f'model.layers.{layer_idx}.self_attn.o_proj.weight'] = model_dict[f'decoder.layers.{layer_idx}.self_attention.linear_proj.weight']

        # Handle MLP weights
        mlp_weight = model_dict[f'decoder.layers.{layer_idx}.mlp.linear_fc1.weight']
        ffn_hidden_size = mlp_weight.shape[0] // 2
        checkpoint[f'model.layers.{layer_idx}.mlp.gate_proj.weight'] = mlp_weight[:ffn_hidden_size, :]
        checkpoint[f'model.layers.{layer_idx}.mlp.up_proj.weight'] = mlp_weight[ffn_hidden_size:, :]
        checkpoint[f'model.layers.{layer_idx}.mlp.down_proj.weight'] = model_dict[f'decoder.layers.{layer_idx}.mlp.linear_fc2.weight']

        # Save layer norms
        checkpoint[f'model.layers.{layer_idx}.input_layernorm.weight'] = model_dict[f'decoder.layers.{layer_idx}.self_attention.linear_qkv.layer_norm_weight']
        checkpoint[f'model.layers.{layer_idx}.post_attention_layernorm.weight'] = model_dict[f'decoder.layers.{layer_idx}.mlp.linear_fc1.layer_norm_weight']

        print(f"Done layer {layer_idx}")

    # Save final layer norm
    checkpoint['model.norm.weight'] = model_dict['decoder.final_layernorm.weight']
    
    # Handle output layer (weight tying if needed)
    if not args.untie_embeddings_and_output_weights:
        checkpoint['lm_head.weight'] = checkpoint['model.embed_tokens.weight']
    else:
        checkpoint['lm_head.weight'] = model_dict['output_layer.weight']

    return checkpoint

# Convert and load the model
model_dict = checkpoint['model']
hf_dict = convert_megatron_to_hf_dict(model_dict)

Done layer 0
Done layer 1
Done layer 2
Done layer 3
Done layer 4
Done layer 5
Done layer 6
Done layer 7
Done layer 8
Done layer 9
Done layer 10
Done layer 11
Done layer 12
Done layer 13
Done layer 14
Done layer 15


In [12]:
model = AutoModelForCausalLM.from_config(config)

# Load the state dict
# You might need to process the state dict to match transformer's format
model.load_state_dict(hf_dict)

<All keys matched successfully>

In [13]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [7]:
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params:,}")

Trainable parameters: 1,235,814,400


In [6]:
from src.infer.convert_megatron_to_hf import convert_megatron_checkpoint_to_hf

In [8]:
convert_megatron_checkpoint_to_hf(
    checkpoint_path='/iopsstor/scratch/cscs/xyixuan/Megatron-LM/logs/Meg-Runs/Goldfish/llama3-1b-15n-8192sl-60gbsz-goldfish-no-bos/torch/iter_0170005/mp_rank_00/model_optim_rng.pt'
)


Model Arguments:
num_layers........................................                            16
encoder_num_layers................................                            16
decoder_num_layers................................                          None
hidden_size.......................................                          2048
ffn_hidden_size...................................                          8192
num_attention_heads...............................                            32
attention_backend.................................              AttnBackend.auto
kv_channels.......................................                            64
group_query_attention.............................                          True
num_query_groups..................................                             8
max_position_embeddings...........................                          8192
position_embedding_type...........................                          rope
relative_a

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [21]:
args = checkpoint['args']
for key, value in vars(args).items():
   print(f"{key}: {value}")

num_layers: 16
encoder_num_layers: 16
decoder_num_layers: None
hidden_size: 2048
ffn_hidden_size: 8192
num_attention_heads: 32
attention_backend: AttnBackend.auto
kv_channels: 64
group_query_attention: True
num_query_groups: 8
max_position_embeddings: 8192
position_embedding_type: rope
relative_attention_num_buckets: 32
relative_attention_max_distance: 128
use_rotary_position_embeddings: False
rotary_base: 500000
rotary_percent: 1.0
rotary_interleaved: False
rotary_seq_len_interpolation_factor: None
use_rope_scaling: True
rope_scaling_factor: 32.0
add_position_embedding: True
make_vocab_size_divisible_by: 128
normalization: RMSNorm
norm_epsilon: 1e-05
apply_layernorm_1p: False
apply_residual_connection_post_layernorm: False
openai_gelu: False
squared_relu: False
swiglu: True
onnx_safe: None
bert_binary_head: True
untie_embeddings_and_output_weights: False
multi_latent_attention: False
attention_dropout: 0.0
hidden_dropout: 0.0
weight_decay: 0.1
start_weight_decay: 0.1
end_weight_decay:

In [32]:
args.hidden_size/args.num_attention_heads

64.0

In [6]:
checkpoint['model'].keys()

odict_keys(['embedding.word_embeddings.weight', 'decoder.layers.0.self_attention.core_attention._extra_state', 'decoder.layers.0.self_attention.linear_proj.weight', 'decoder.layers.0.self_attention.linear_proj._extra_state', 'decoder.layers.0.self_attention.linear_qkv.layer_norm_weight', 'decoder.layers.0.self_attention.linear_qkv.weight', 'decoder.layers.0.self_attention.linear_qkv._extra_state', 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight', 'decoder.layers.0.mlp.linear_fc1.weight', 'decoder.layers.0.mlp.linear_fc1._extra_state', 'decoder.layers.0.mlp.linear_fc2.weight', 'decoder.layers.0.mlp.linear_fc2._extra_state', 'decoder.layers.1.self_attention.core_attention._extra_state', 'decoder.layers.1.self_attention.linear_proj.weight', 'decoder.layers.1.self_attention.linear_proj._extra_state', 'decoder.layers.1.self_attention.linear_qkv.layer_norm_weight', 'decoder.layers.1.self_attention.linear_qkv.weight', 'decoder.layers.1.self_attention.linear_qkv._extra_state', 'decoder.laye

In [9]:
# Check hidden size
qkv_weight = checkpoint['model']['decoder.layers.0.self_attention.linear_qkv.weight']
print(f"QKV weight shape: {qkv_weight.shape}")

# Check dtype
print(f"QKV weight dtype: {qkv_weight.dtype}")

# Check embedding size
emb_weight = checkpoint['model']['embedding.word_embeddings.weight']
print(f"Embedding shape: {emb_weight.shape}")

QKV weight shape: torch.Size([3072, 2048])
QKV weight dtype: torch.bfloat16
Embedding shape: torch.Size([128256, 2048])
