In [1]:
from pprint import pprint
from parsers import ModelParser

import torch
#from transformers import AutoTokenizer

model_parser = ModelParser([
    "../Meta-Llama-3-8B/model-00001-of-00004.safetensors",
    "../Meta-Llama-3-8B/model-00002-of-00004.safetensors",
    "../Meta-Llama-3-8B/model-00003-of-00004.safetensors",
    "../Meta-Llama-3-8B/model-00004-of-00004.safetensors",
])

In [2]:
pprint(model_parser.tensor_names)

['model.embed_tokens.weight',
 'model.layers.0.input_layernorm.weight',
 'model.layers.0.mlp.down_proj.weight',
 'model.layers.0.mlp.gate_proj.weight',
 'model.layers.0.mlp.up_proj.weight',
 'model.layers.0.post_attention_layernorm.weight',
 'model.layers.0.self_attn.k_proj.weight',
 'model.layers.0.self_attn.o_proj.weight',
 'model.layers.0.self_attn.q_proj.weight',
 'model.layers.0.self_attn.v_proj.weight',
 'model.layers.1.input_layernorm.weight',
 'model.layers.1.mlp.down_proj.weight',
 'model.layers.1.mlp.gate_proj.weight',
 'model.layers.1.mlp.up_proj.weight',
 'model.layers.1.post_attention_layernorm.weight',
 'model.layers.1.self_attn.k_proj.weight',
 'model.layers.1.self_attn.o_proj.weight',
 'model.layers.1.self_attn.q_proj.weight',
 'model.layers.1.self_attn.v_proj.weight',
 'model.layers.2.input_layernorm.weight',
 'model.layers.2.mlp.down_proj.weight',
 'model.layers.2.mlp.gate_proj.weight',
 'model.layers.2.mlp.up_proj.weight',
 'model.layers.2.post_attention_layernorm.we

## Prepare text and embeddings

In [3]:
tokenizer = AutoTokenizer.from_pretrained("../Meta-Llama-3-8B/")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
)#.to(model.device)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


No chat template is defined for this tokenizer - using a default chat template that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.



In [5]:
input_ids

tensor([[   27,    91,   318,  5011,    91,    29,  9125,   198,  2675,   527,
           264, 55066,  6369,  6465,   889,  2744, 31680,   304, 55066,  6604,
         88032,    91,   318,  6345,    91,   397,    27,    91,   318,  5011,
            91,    29,   882,   198, 15546,   527,   499, 76514,    91,   318,
          6345,    91,   397,    27,    91,   318,  5011,    91,    29, 78191,
           198]])

## Forward passes

In [3]:
# load testing data
import pickle
with open("intermediate_data_meta_llama3_8b.pkl", "rb") as f:
    data_orig = pickle.load(f)

In [4]:
from transformer_ops import embedding_matrix, RMSNorm, Attention, FFN, remap_weights_if_needed

In [5]:
model_embed_tokens_weight = model_parser.get_tensor('model.embed_tokens.weight')
input_embeddings = embedding_matrix(inputs=input_ids, weights=model_embed_tokens_weight)

  parsed_vals = torch.frombuffer(tensor, dtype=torch.bfloat16).reshape(shape)


NameError: name 'input_ids' is not defined

In [5]:
model_layers_0_input_layernorm = RMSNorm()
model_layers_0_post_attention_layernorm = RMSNorm()

In [6]:
model_layers_0_input_layernorm_weight = model_parser.get_tensor('model.layers.0.input_layernorm.weight')
model_layers_0_post_attention_layernorm_weight = model_parser.get_tensor('model.layers.0.post_attention_layernorm.weight')

In [7]:
torch.allclose(
    model_layers_0_input_layernorm.forward(inputs=data_orig["inputs"]["layers.0.attention_norm"][0].cpu(), weights=model_layers_0_input_layernorm_weight),
    data_orig["outputs"]["layers.0.attention_norm"].cpu()
)

True

In [8]:
torch.allclose(
    model_layers_0_post_attention_layernorm.forward(inputs=data_orig["inputs"]["layers.0.ffn_norm"][0].cpu(), weights=model_layers_0_post_attention_layernorm_weight),
    data_orig["outputs"]["layers.0.ffn_norm"].cpu()
)

True

## Test Rope and Attention modules

In [5]:
config = {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": False,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": None,
  "rope_theta": 500000.0,
  "tie_word_embeddings": False,
  "torch_dtype": torch.bfloat16,
  "transformers_version": "4.40.0.dev0",
  "use_cache": True,
  "vocab_size": 128256
}

def precompute_rope_constants(dim: int, end: int, theta: float = 10000.0):
    """
    RoPE:
        - https://blog.eleuther.ai/rotary-embeddings/
        - https://github.com/meta-llama/llama3/blob/main/llama/model.py
    """ 
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

In [6]:
freqs_rope = precompute_rope_constants(
    config["hidden_size"] // config["num_attention_heads"],
    config["max_position_embeddings"] * 2,
    config["rope_theta"],
)

In [7]:
freqs_rope.shape

torch.Size([16384, 64])

In [8]:
# Loads all weights for an attention module, renaming and remapping weights if needed
layer_idx = 0
attn_weights = {}
for letter in ['q', 'k', 'v', 'o']:
    orig_key = f'model.layers.{layer_idx}.self_attn.{letter}_proj.weight'
    new_key = f'self_attn.{letter}_proj.weight'
    attn_weights[new_key] = remap_weights_if_needed(model_parser.get_tensor(orig_key), param_name=orig_key, config=config)

In [9]:
attention_0 = Attention(config, max_seq_len=128)

In [10]:
# inputs to attention's projections are going to be the same for each projection MLP, and this input also corresponds to the attention block input
attn_block_input_orig = data_orig["inputs"]["layers.0.attention"][0].cpu()
attn_block_output_orig = data_orig["outputs"]["layers.0.attention.wo"].cpu()

print(attn_block_input_orig.shape)
print(attn_block_input_orig.shape)

torch.Size([1, 72, 4096])
torch.Size([1, 72, 4096])


In [11]:
start_pos = 0
seq_len = attn_block_input_orig.shape[1]

mask = None
if seq_len > 1:
    mask = torch.full((seq_len, seq_len), float("-inf"))#, device=tokens.device)

    mask = torch.triu(mask, diagonal=1)

    # When performing key-value caching, we compute the attention scores
    # only for the new sequence. Thus, the matrix of scores is of size
    # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
    # j > cache_len + i, since row i corresponds to token cache_len + i.
    mask = torch.hstack(
        #[torch.zeros((seq_len, start_pos), device=tokens.device), mask]
        [torch.zeros((seq_len, start_pos)), mask]
    ).to(torch.bfloat16)#.type_as(h)

In [12]:
torch.allclose(
    data_orig["inputs"]["layers.0.attention"][2].cpu(), freqs_rope[start_pos:start_pos+seq_len]
)

True

In [13]:
torch.allclose(
    data_orig["inputs"]["layers.0.attention"][2].cpu(), freqs_rope[start_pos:start_pos+seq_len]
)

True

In [14]:
attn_block_out_ours = attention_0.forward(attn_block_input_orig, start_pos, weights=attn_weights, mask=mask, freqs_rope=freqs_rope[start_pos:start_pos+seq_len])

In [15]:
torch.allclose(
    attn_block_out_ours, attn_block_output_orig
)

False

In [16]:
torch.allclose(
    attn_block_input_orig,
    data_orig["inputs"]["layers.0.attention.wq"][0].cpu()
)

True

In [17]:
torch.allclose(
    attn_block_input_orig,
    data_orig["inputs"]["layers.0.attention.wk"][0].cpu()
)

True

In [18]:
torch.allclose(
    data_orig["outputs"]["layers.0.attention.wq"].cuda(),
    torch.nn.functional.linear(data_orig["inputs"]["layers.0.attention.wq"][0].cuda(), attn_weights["self_attn.q_proj.weight"].cuda())
)

True

In [19]:
torch.allclose(
    data_orig["outputs"]["layers.0.attention.wk"].cuda(),
    torch.nn.functional.linear(data_orig["inputs"]["layers.0.attention.wk"][0].cuda(), attn_weights["self_attn.k_proj.weight"].cuda())
)

True

In [20]:
torch.allclose(
    data_orig["outputs"]["layers.0.attention.wv"].cuda(),
    torch.nn.functional.linear(data_orig["inputs"]["layers.0.attention.wv"][0].cuda(), attn_weights["self_attn.v_proj.weight"].cuda())
)

True

In [21]:
torch.allclose(
    data_orig["outputs"]["layers.0.attention.wo"].cuda(),
    torch.nn.functional.linear(data_orig["inputs"]["layers.0.attention.wo"][0].cuda(), attn_weights["self_attn.o_proj.weight"].cuda())
)

True

## FFN

In [22]:
ffn = FFN()

In [23]:
ffn_weights = {
    "mlp.down_proj.weight": model_parser.get_tensor('model.layers.0.mlp.down_proj.weight'),
    "mlp.gate_proj.weight": model_parser.get_tensor('model.layers.0.mlp.gate_proj.weight'),
    "mlp.up_proj.weight": model_parser.get_tensor('model.layers.0.mlp.up_proj.weight'),
}

In [24]:
ffn_input_orig = data_orig["inputs"]["layers.0.feed_forward"][0].cpu()
ffn_output_orig = data_orig["outputs"]["layers.0.feed_forward"].cpu()

print(attn_block_input_orig.shape)
print(attn_block_input_orig.shape)

torch.Size([1, 72, 4096])
torch.Size([1, 72, 4096])


In [25]:
ffn_input_orig.shape

torch.Size([1, 72, 4096])

In [26]:
ffn_output_ours = ffn.forward(ffn_input_orig, ffn_weights)

In [33]:
torch.allclose(
    ffn_output_ours,
    ffn_output_orig
)

False

In [34]:
torch.mean(
    torch.abs(
    ffn_output_ours - ffn_output_orig
    )
)

tensor(4.7497e-08, dtype=torch.bfloat16)

In [30]:
ffn_output_ours

tensor([[[ 1.0109e-04,  4.7607e-03, -5.7983e-03,  ...,  1.0010e-02,
           4.0894e-03, -5.3711e-03],
         [ 8.1787e-03,  2.7588e-02,  6.7139e-04,  ..., -4.6387e-02,
          -2.4780e-02, -1.2451e-02],
         [ 2.5787e-03,  1.1292e-03,  1.3611e-02,  ...,  3.0884e-02,
          -1.7334e-02, -1.4709e-02],
         ...,
         [ 1.6689e-05,  7.6675e-04, -1.2283e-03,  ..., -6.7139e-04,
          -7.6675e-04, -1.8883e-04],
         [ 1.3530e-05,  7.4768e-04, -1.1902e-03,  ..., -6.5994e-04,
          -7.4768e-04, -1.8501e-04],
         [ 1.2279e-05,  7.3242e-04, -1.1673e-03,  ..., -6.4468e-04,
          -7.2861e-04, -1.8215e-04]]], dtype=torch.bfloat16)

In [31]:
ffn_output_orig

tensor([[[ 1.0109e-04,  4.7607e-03, -5.7983e-03,  ...,  1.0010e-02,
           4.0894e-03, -5.3711e-03],
         [ 8.1787e-03,  2.7588e-02,  6.7139e-04,  ..., -4.6387e-02,
          -2.4780e-02, -1.2451e-02],
         [ 2.5787e-03,  1.1292e-03,  1.3611e-02,  ...,  3.0884e-02,
          -1.7334e-02, -1.4709e-02],
         ...,
         [ 1.6689e-05,  7.6675e-04, -1.2283e-03,  ..., -6.7139e-04,
          -7.6675e-04, -1.8883e-04],
         [ 1.3530e-05,  7.4768e-04, -1.1902e-03,  ..., -6.5994e-04,
          -7.4768e-04, -1.8501e-04],
         [ 1.2159e-05,  7.3242e-04, -1.1673e-03,  ..., -6.4468e-04,
          -7.2861e-04, -1.8215e-04]]], dtype=torch.bfloat16)