In [8]:
import json

import torch
from safetensors.torch import load_file

# Directory containing the Hugging Face checkpoint that is read (config.json
# here, model.safetensors further down). File names are appended with `+`,
# so the value must end with a path separator.
LOAD_PATH = "/Users/blacksamorez/models/Meta-Llama-3.1-8B-Instruct-AQLM-PV-2Bit-2x8-hf/"
# Directory receiving the converted artifacts. Same trailing-separator rule.
SAVE_PATH = "/Users/blacksamorez/models/Meta-Llama-3.1-8B-Instruct-AQLM-PV-2Bit-2x8-hf/"


with open(LOAD_PATH + "config.json", "r") as file:
    hf_config = json.load(file)

# Re-express the Hugging Face config fields under the key names used by
# params.json. Built up front so the output file is only opened once every
# required config key has been found.
# NOTE(review): no RoPE settings (e.g. `rope_theta`) are exported; confirm the
# consumer of params.json has the right defaults for this model.
params = {
    "dim": hf_config["hidden_size"],
    # Hard-coded rounding granularity for the feed-forward width.
    "multiple_of": 256,
    "n_heads": hf_config["num_attention_heads"],
    "n_kv_heads": hf_config["num_key_value_heads"],
    "n_layers": hf_config["num_hidden_layers"],
    "norm_eps": hf_config["rms_norm_eps"],
    "vocab_size": hf_config["vocab_size"],
    # Ratio between the real intermediate size and the 8/3 * hidden_size
    # baseline, so the consumer can recover `intermediate_size`.
    "ffn_dim_multiplier":  hf_config['intermediate_size'] / (8 / 3 * hf_config['hidden_size']),
}

# params.json is an output artifact, so it belongs in SAVE_PATH (it used to be
# written into LOAD_PATH, which only worked because both paths were equal).
with open(SAVE_PATH + "params.json", "w") as file:
    json.dump(params, file)

hidden_size = hf_config["hidden_size"]
# Use the explicit per-head width when the config carries one; otherwise fall
# back to hidden_size / num_attention_heads.
head_dim = hf_config.get("head_dim") or hidden_size // hf_config["num_attention_heads"]


def _interleave_head_halves(tensor, head_dim):
    """Re-order the leading axis of *tensor* head by head.

    The leading axis is viewed as consecutive groups of ``head_dim`` rows.
    Inside every group the first and second halves are interleaved: row ``i``
    of the first half is followed by row ``i`` of the second half. All
    trailing dimensions are carried along unchanged, so the same routine
    serves both the code tensors (noted in the original as
    ``[num_out_groups, num_in_groups, num_codebooks]``) and the scale tensors
    (``[num_out_groups, 1, 1, 1]``) without hard-coding the group size or the
    number of codebooks.

    NOTE(review): presumably this converts the q/k projections between the
    Hugging Face and the target rotary-embedding row order -- confirm.

    Returns a tensor with the same shape as the input.
    """
    return (
        tensor.reshape(-1, 2, head_dim // 2, *tensor.shape[1:])
        .transpose(1, 2)
        .reshape(tensor.shape)
    )


# Called `state_dict` instead of `dict` so the builtin is not shadowed.
state_dict = load_file(LOAD_PATH + "model.safetensors")

# Substring rewrites applied, in this order, to every checkpoint key.
mapping = {
    "model.": "",

    "self_attn.q_proj": "attention.wq",
    "self_attn.k_proj": "attention.wk",
    "self_attn.v_proj": "attention.wv",
    "self_attn.o_proj": "attention.wo",

    "mlp.up_proj": "feed_forward.w3",
    "mlp.gate_proj": "feed_forward.w1",
    "mlp.down_proj": "feed_forward.w2",

    "input_layernorm": "attention_norm",
    "post_attention_layernorm": "ffn_norm",

    "lm_head": "output",
    "embed_tokens": "tok_embeddings",
}

# Renamed-key fragments identifying the q/k tensors whose rows are re-ordered.
_PER_HEAD_REORDERED = (
    "attention.wq.codes",
    "attention.wk.codes",
    "attention.wq.scales",
    "attention.wk.scales",
)


new_dict = {}


for key, value in state_dict.items():
    for old, new in mapping.items():
        key = key.replace(old, new)

    if any(fragment in key for fragment in _PER_HEAD_REORDERED):
        value = _interleave_head_halves(value, head_dim)

    if "codes" in key:
        value = value.transpose(0, 1)  # <- Special memory layout for lut kernels

    # Half-precision tensors are stored as float32; other dtypes are kept.
    if value.dtype == torch.float16:
        value = value.float()

    print(f"{key}: {value.shape=}")
    new_dict[key] = value

torch.save(new_dict, SAVE_PATH + "model.pth")

output.weight: value.shape=torch.Size([128256, 4096])
tok_embeddings.weight: value.shape=torch.Size([128256, 4096])
layers.0.attention_norm.weight: value.shape=torch.Size([4096])
layers.0.feed_forward.w2.codebooks: value.shape=torch.Size([2, 256, 1, 8])
layers.0.feed_forward.w2.codes: value.shape=torch.Size([1792, 4096, 2])
layers.0.feed_forward.w2.scales: value.shape=torch.Size([4096, 1, 1, 1])
layers.0.feed_forward.w1.codebooks: value.shape=torch.Size([2, 256, 1, 8])
layers.0.feed_forward.w1.codes: value.shape=torch.Size([512, 14336, 2])
layers.0.feed_forward.w1.scales: value.shape=torch.Size([14336, 1, 1, 1])
layers.0.feed_forward.w3.codebooks: value.shape=torch.Size([2, 256, 1, 8])
layers.0.feed_forward.w3.codes: value.shape=torch.Size([512, 14336, 2])
layers.0.feed_forward.w3.scales: value.shape=torch.Size([14336, 1, 1, 1])
layers.0.ffn_norm.weight: value.shape=torch.Size([4096])
layers.0.attention.wk.codebooks: value.shape=torch.Size([2, 256, 1, 8])
layers.0.attention.wk.codes: v