In [1]:
import pickle
from pprint import pprint

import torch
from transformers import AutoTokenizer

from parsers import ModelParser

# Parse only the first of the four safetensors shards of the checkpoint.
model_parser = ModelParser(["../Meta-Llama-3-8B/model-00001-of-00004.safetensors"])

In [2]:
# List the tensor names exposed by the parser for the loaded shard.
pprint(model_parser.tensor_names)

['model.embed_tokens.weight',
 'model.layers.0.input_layernorm.weight',
 'model.layers.0.mlp.down_proj.weight',
 'model.layers.0.mlp.gate_proj.weight',
 'model.layers.0.mlp.up_proj.weight',
 'model.layers.0.post_attention_layernorm.weight',
 'model.layers.0.self_attn.k_proj.weight',
 'model.layers.0.self_attn.o_proj.weight',
 'model.layers.0.self_attn.q_proj.weight',
 'model.layers.0.self_attn.v_proj.weight',
 'model.layers.1.input_layernorm.weight',
 'model.layers.1.mlp.down_proj.weight',
 'model.layers.1.mlp.gate_proj.weight',
 'model.layers.1.mlp.up_proj.weight',
 'model.layers.1.post_attention_layernorm.weight',
 'model.layers.1.self_attn.k_proj.weight',
 'model.layers.1.self_attn.o_proj.weight',
 'model.layers.1.self_attn.q_proj.weight',
 'model.layers.1.self_attn.v_proj.weight',
 'model.layers.2.input_layernorm.weight',
 'model.layers.2.mlp.down_proj.weight',
 'model.layers.2.mlp.gate_proj.weight',
 'model.layers.2.mlp.up_proj.weight',
 'model.layers.2.post_attention_layernorm.we

## Prepare text and embeddings

In [3]:
# Load the tokenizer from the local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("../Meta-Llama-3-8B/")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Two-turn conversation (system + user) used as the test prompt.
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Render the conversation with the tokenizer's chat template and return the
# token ids as a PyTorch tensor.
# NOTE(review): the warning recorded under this cell says the tokenizer defines
# no chat template, so a default ChatML template is applied. Confirm this is the
# prompt format the reference activations were captured with before changing it.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
)#.to(model.device)

# The tokenizer's EOS id plus the id of "<|eot_id|>".
# Presumably intended as generation stop ids; not used in the visible cells.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]


No chat template is defined for this tokenizer - using a default chat template that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.



In [5]:
# Display the token ids produced by the chat template.
input_ids

tensor([[   27,    91,   318,  5011,    91,    29,  9125,   198,  2675,   527,
           264, 55066,  6369,  6465,   889,  2744, 31680,   304, 55066,  6604,
         88032,    91,   318,  6345,    91,   397,    27,    91,   318,  5011,
            91,    29,   882,   198, 15546,   527,   499, 76514,    91,   318,
          6345,    91,   397,    27,    91,   318,  5011,    91,    29, 78191,
           198]])

## Forward passes

In [3]:
# Load the recorded test data (per-module "inputs" and "outputs" dictionaries
# that the cells below compare against).
# `pickle` is imported in the notebook's top import cell.
# NOTE: pickle.load can execute arbitrary code embedded in the file; only load
# pickles that come from a trusted source.
with open("intermediate_data_llama3_8b.pkl", "rb") as f:
    data_orig = pickle.load(f)

In [4]:
from transformer_ops import embedding_matrix, RMSNorm, Attention

In [5]:
# Compute the input embeddings from the token ids and the embedding weight tensor.
# NOTE(review): the recorded output of this cell is a NameError. `input_ids` is
# defined by the tokenizer cells, which were evidently not run in this kernel
# session. Re-run them first (Restart & Run All) before relying on this value.
model_embed_tokens_weight = model_parser.get_tensor('model.embed_tokens.weight')
input_embeddings = embedding_matrix(inputs=input_ids, weights=model_embed_tokens_weight)

  parsed_vals = torch.frombuffer(tensor, dtype=torch.bfloat16).reshape(shape)


NameError: name 'input_ids' is not defined

In [6]:
# One RMSNorm instance per normalisation site of layer 0. The weights are not
# bound here; they are passed to `forward` at call time.
model_layers_0_input_layernorm = RMSNorm()
model_layers_0_post_attention_layernorm = RMSNorm()

In [7]:
# Fetch the two layer-0 norm weight tensors from the parsed shard.
model_layers_0_input_layernorm_weight = model_parser.get_tensor('model.layers.0.input_layernorm.weight')
model_layers_0_post_attention_layernorm_weight = model_parser.get_tensor('model.layers.0.post_attention_layernorm.weight')

In [8]:
# Take the recorded input of layer 0's input_layernorm as the embeddings.
# NOTE(review): this rebinds `input_embeddings`, replacing whatever the
# embedding_matrix cell assigned to the same name.
input_embeddings = data_orig["inputs"]["model.layers.0.input_layernorm"][0]

In [9]:
# Our input_layernorm output must match the recorded reference output.
torch.allclose(
    model_layers_0_input_layernorm.forward(
        inputs=data_orig["inputs"]["model.layers.0.input_layernorm"][0],
        weights=model_layers_0_input_layernorm_weight,
    ),
    data_orig["outputs"]["model.layers.0.input_layernorm"],
)

True

In [10]:
# Same check for the post-attention norm of layer 0.
torch.allclose(
    model_layers_0_post_attention_layernorm.forward(
        inputs=data_orig["inputs"]["model.layers.0.post_attention_layernorm"][0],
        weights=model_layers_0_post_attention_layernorm_weight,
    ),
    data_orig["outputs"]["model.layers.0.post_attention_layernorm"],
)

True

## Test RoPE and Attention modules

In [11]:
# Model hyper-parameters as a plain dict (presumably mirroring the checkpoint's
# config.json; verify against the checkpoint directory).
config = {
    # Identification
    "architectures": ["LlamaForCausalLM"],
    # Attention options and special token ids
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    # Layer sizes
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    # Normalisation and rotary embedding settings
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    # Miscellaneous
    "tie_word_embeddings": False,
    "torch_dtype": torch.bfloat16,
    "transformers_version": "4.40.0.dev0",
    "use_cache": True,
    "vocab_size": 128256,
}

def precompute_rope_constants(dim: int, end: int, theta: float = 10000.0):
    """
    Build the table of complex rotation factors used by rotary embeddings (RoPE).

    References:
        - https://blog.eleuther.ai/rotary-embeddings/
        - https://github.com/meta-llama/llama3/blob/main/llama/model.py

    Args:
        dim: Size of the dimension being rotated; one frequency is produced per
            pair, i.e. ``dim // 2`` frequencies.
        end: Number of positions (0 .. end - 1) to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        A complex64 tensor of shape ``(end, dim // 2)`` whose entry ``[p, i]``
        is the unit-magnitude complex number with angle
        ``p / theta ** (2 * i / dim)``.
    """
    # Exponents 0, 2/dim, 4/dim, ... -> inverse frequencies theta ** -exponent.
    exponents = torch.arange(0, dim, 2)[: dim // 2].float() / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Angle for every (position, frequency) pair.
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)

    # Unit magnitude with the given angle: cos(angle) + i * sin(angle).
    return torch.polar(torch.ones_like(angles), angles)

In [25]:
# RoPE table for one attention head, tabulated for twice the maximum context length.
freqs_rope = precompute_rope_constants(
    dim=config["hidden_size"] // config["num_attention_heads"],
    end=config["max_position_embeddings"] * 2,
    theta=config["rope_theta"],
)

In [26]:
# Sanity-check the table shape: (tabulated positions, head_dim // 2).
freqs_rope.shape

torch.Size([16384, 64])

In [27]:
# Layer-0 attention projection weights, keyed by their name relative to the layer.
attn_weights = {
    f"self_attn.{proj}.weight": model_parser.get_tensor(f"model.layers.0.self_attn.{proj}.weight")
    for proj in ("q_proj", "k_proj", "v_proj", "o_proj")
}

In [28]:
# Attention module under test, constructed from the config dict.
attention_0 = Attention(config)

In [29]:
# The q/k/v projections all receive the same tensor, and that tensor is also the
# input of the attention block, so the recorded q_proj input serves as the block input.
attn_block_input_orig = data_orig["inputs"]["model.layers.0.self_attn.q_proj"][0]
attn_block_output_orig = data_orig["outputs"]["model.layers.0.self_attn"][0]

# Show both shapes (the second print used to repeat the input shape).
print(attn_block_input_orig.shape)
print(attn_block_output_orig.shape)

torch.Size([1, 51, 4096])
torch.Size([1, 51, 4096])


In [30]:
# Re-display the RoPE table shape before it is sliced for the forward pass.
freqs_rope.shape

torch.Size([16384, 64])

In [31]:
start_pos = 0
seq_len = attn_block_input_orig.shape[1]

# A single-token step needs no mask; longer sequences get a causal one.
mask = None
if seq_len > 1:
    # With key-value caching, scores are computed only for the new tokens, so
    # the score matrix has shape (seq_len, start_pos + seq_len). Row i is token
    # start_pos + i, and the only masked entries are the columns
    # j > start_pos + i: the left block of zeros leaves the cached positions
    # visible, the right block holds -inf strictly above the diagonal.
    mask = torch.hstack(
        [
            torch.zeros((seq_len, start_pos)),
            torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1),
        ]
    ).to(torch.bfloat16)

In [38]:
# Run our attention block on the recorded input. The RoPE table is sliced to the
# positions covered by this sequence.
attn_block_out_ours = attention_0.forward(
    attn_block_input_orig,
    start_pos,
    weights=attn_weights,
    mask=mask,
    freqs_rope=freqs_rope[start_pos : start_pos + seq_len],
)

In [39]:
# Reference attention output recorded in the test data.
attn_block_output_orig

tensor([[[ 7.8125e-03,  1.2665e-03, -1.7395e-03,  ..., -3.9062e-03,
          -9.0332e-03, -3.6469e-03],
         [ 9.2773e-03,  2.9449e-03, -3.6163e-03,  ..., -5.6458e-03,
          -5.9204e-03,  5.2490e-03],
         [-4.6692e-03, -5.0354e-03,  2.9945e-04,  ...,  7.6904e-03,
          -4.6387e-03,  4.3297e-04],
         ...,
         [-3.1738e-03,  5.4169e-04, -8.9722e-03,  ..., -3.0060e-03,
          -3.4180e-03, -2.2430e-03],
         [ 1.7090e-03, -9.2316e-04,  5.2185e-03,  ..., -4.3678e-04,
          -1.1902e-03, -1.2207e-03],
         [ 5.2185e-03,  3.2043e-03, -6.8359e-03,  ..., -4.1504e-03,
          -2.7061e-05, -6.8665e-04]]], dtype=torch.bfloat16)

In [40]:
# Output of our attention implementation, for visual comparison with the reference.
attn_block_out_ours

tensor([[[ 0.0078,  0.0013, -0.0017,  ..., -0.0039, -0.0090, -0.0036],
         [ 0.0075,  0.0058, -0.0045,  ...,  0.0028, -0.0044,  0.0002],
         [-0.0039, -0.0029, -0.0013,  ...,  0.0105, -0.0037,  0.0004],
         ...,
         [-0.0043, -0.0040, -0.0078,  ...,  0.0007, -0.0041, -0.0003],
         [ 0.0006, -0.0051,  0.0033,  ...,  0.0010, -0.0021, -0.0024],
         [-0.0004, -0.0002,  0.0051,  ...,  0.0070, -0.0038, -0.0054]]],
       dtype=torch.bfloat16)

In [41]:
# Element-wise absolute difference between the reference output and ours.
torch.abs(attn_block_output_orig - attn_block_out_ours)

tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 0.0000e+00],
         [1.7700e-03, 2.8839e-03, 8.6975e-04,  ..., 8.4229e-03,
          1.5259e-03, 5.0659e-03],
         [7.6294e-04, 2.1515e-03, 1.6251e-03,  ..., 2.8076e-03,
          9.1553e-04, 1.9073e-06],
         ...,
         [1.1292e-03, 4.5471e-03, 1.1597e-03,  ..., 3.6621e-03,
          6.4087e-04, 1.9684e-03],
         [1.1292e-03, 4.1504e-03, 1.9073e-03,  ..., 1.4420e-03,
          9.0027e-04, 1.1749e-03],
         [5.6458e-03, 3.4485e-03, 1.1902e-02,  ..., 1.1230e-02,
          3.7842e-03, 4.7607e-03]]], dtype=torch.bfloat16)

In [24]:
# Compare our attention output with the reference using torch.allclose's default tolerances.
# NOTE(review): this cell's recorded execution count (24) is lower than that of
# the cells above (38-41), so the recorded result may be stale. Re-run it after
# the forward pass.
# NOTE(review): in the displayed difference the first position matches exactly,
# while later positions differ by amounts comparable to the values themselves.
# Position 0 has a rotation angle of 0 and, under the causal mask, attends only
# to itself. The mismatch therefore presumably comes from how RoPE is applied
# or from the multi-position attention path. One candidate is a
# rotate-half vs. interleaved-pair layout mismatch between the checkpoint's
# q/k weights and the complex-valued RoPE table. TODO confirm in
# transformer_ops.Attention.
torch.allclose(
    attn_block_out_ours, attn_block_output_orig
)

False