In [1]:
from transformers.trainer_utils import set_seed

# Seed the RNGs through transformers' `set_seed` helper before anything else runs,
# so repeated runs of the notebook start from the same random state.
SEED = 6
set_seed(SEED)

In [2]:
from transformers import AutoConfig, LlamaConfig


# Earlier single-head variant (hidden_size = 4, one attention head), kept for reference:
# Llama_config = LlamaConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_hidden_layers = 1, use_cache = False, hidden_size = 4, num_attention_heads = 1, 
#                                            output_hidden_states=True,  num_key_value_heads = 1, past_key_values = True)


# Start from the TinyLlama-1.1B-Chat config and override it down to a toy model:
# one decoder layer, hidden_size 8 and 2 attention heads (head_dim = 8 / 2 = 4),
# with as many key/value heads as attention heads.
# NOTE(review): `past_key_values = True` does not show up in the printed config
# below, so it appears to be ignored — confirm and drop it if so.
Llama_config = LlamaConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_hidden_layers = 1, use_cache = False, hidden_size = 8, num_attention_heads = 2, 
                                           output_hidden_states=True,  num_key_value_heads = 2, past_key_values = True)


# Display the resulting config.
Llama_config

LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 8,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 2,
  "num_hidden_layers": 1,
  "num_key_value_heads": 2,
  "output_hidden_states": true,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": false,
  "vocab_size": 32000
}

In [3]:
from transformers import AutoModel

# Instantiate the model from the toy config defined above.
tinyllama = AutoModel.from_config(Llama_config)

In [4]:
from transformers import LlamaTokenizer

# Short prompt that is traced through the model by hand below.
src_sent = "hi how are"

# Tokenizer loaded from the same checkpoint name the config was taken from.
llama_tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

In [5]:
# Tokenize the prompt by calling the tokenizer directly (the documented
# replacement for the legacy `encode_plus` entry point). The result holds
# `input_ids` and `attention_mask` as PyTorch tensors with a leading batch axis.
tokenized_src_dict = llama_tokenizer(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[   1, 7251,  920,  526]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

In [6]:
# Keep just the token ids; the output below shows a single row of 4 ids.
src_tokenized = tokenized_src_dict["input_ids"]
src_tokenized

tensor([[   1, 7251,  920,  526]])

In [7]:
# Round-trip check: decode the first (and only) sequence of the batch.
# Row 0 is indexed explicitly instead of star-unpacking the batch tensor, because
# `decode(*src_tokenized)` would pass any additional batch rows as the positional
# `skip_special_tokens` argument.
llama_tokenizer.decode(src_tokenized[0])

'<s> hi how are'

In [8]:
# Reference forward pass through the Hugging Face model; the trace printed below
# is what the manual computation in the rest of the notebook is compared against.
output = tinyllama(**tokenized_src_dict)

###########################################################################
LLAMA DECODER FWD START

Attention mask =  None

Input (hidden states) =  tensor([[[-0.0334, -0.0280, -0.0287,  0.0086, -0.0090, -0.0133,  0.0122,
          -0.0035],
         [ 0.0080,  0.0184,  0.0083,  0.0189,  0.0256, -0.0133, -0.0215,
          -0.0009],
         [ 0.0525,  0.0237, -0.0210, -0.0215,  0.0180,  0.0015,  0.0290,
          -0.0205],
         [-0.0094, -0.0067, -0.0051, -0.0405,  0.0052, -0.0106, -0.0160,
           0.0159]]], grad_fn=<EmbeddingBackward0>)

LayerNorm(hidden states) =  tensor([[[-1.6458, -1.3806, -1.4139,  0.4244, -0.4413, -0.6542,  0.5997,
          -0.1725],
         [ 0.4819,  1.1102,  0.4978,  1.1398,  1.5423, -0.8012, -1.2943,
          -0.0517],
         [ 1.9347,  0.8732, -0.7735, -0.7947,  0.6625,  0.0553,  1.0699,
          -0.7544],
         [-0.5286, -0.3752, -0.2874, -2.2779,  0.2938, -0.5967, -0.9011,
           0.8947]]], grad_fn=<MulBackward0>)

position_ids =  te

## Verifying the RoPE of Llama2 

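This walkthrough applies ONLY to multi-head attention, i.e. `num_key_value_heads == num_attention_heads` (as in the config above); the key/value tensors below are split with the same number of heads as the queries.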

In [9]:
# Derive the dimensions from the objects they describe instead of repeating them
# as literals, so this cell cannot drift out of sync with the config or the prompt.
embed_dim = Llama_config.hidden_size            # 8 for the config above
num_heads = Llama_config.num_attention_heads    # 2 for the config above
seq_len = src_tokenized.size(1)                 # 4 ids for "hi how are" (BOS included)

# Per-head width; the hidden size has to split evenly across the heads.
assert embed_dim % num_heads == 0, "hidden_size must be divisible by num_attention_heads"
head_dim = embed_dim // num_heads

In [10]:
# Grab the model's parameter tensors by name so each step can be recomputed by hand.
state_dict = tinyllama.state_dict()

In [11]:
import torch

def look_up_table(sentence, vocab_embeds, embedding):
    """Fill `embedding` with the embedding row of every token id in `sentence`.

    Args:
        sentence: 2-D tensor of token ids, indexed as `sentence[row, col]`.
        vocab_embeds: 2-D embedding matrix; row `t` is the vector for token id `t`.
        embedding: pre-allocated output buffer indexed as `embedding[row, col, :]`.
            It is written in place.

    Returns:
        The same `embedding` object that was passed in, now populated.

    Raises:
        ValueError: if a token id is negative or not smaller than the number of
            rows in `vocab_embeds`.
    """
    vocab_size = vocab_embeds.size(0)
    n_rows = sentence.size(0)
    n_cols = sentence.size(1)

    for row in range(n_rows):
        for col in range(n_cols):
            # Plain Python number for the token id at this position.
            word_index = sentence[row, col].item()

            # Reject ids that fall outside the embedding matrix.
            if word_index >= vocab_size or word_index < 0:
                raise ValueError(f"Invalid word index: {word_index}")

            # Copy the token's vector into the output buffer.
            embedding[row, col, :] = vocab_embeds[word_index, :]

            print(f"Word index: {word_index}, Embedding: {vocab_embeds[word_index, :]}")

    # Blank line separating the per-token trace from whatever is printed next.
    print()

    return embedding

In [12]:
def get_embedding_outputs(src_tokens, state_dict, d_model):
    """Embed `src_tokens` using the `embed_tokens.weight` matrix from `state_dict`.

    Args:
        src_tokens: 2-D tensor of token ids.
        state_dict: mapping that contains the key "embed_tokens.weight".
        d_model: size of the last axis of the returned embedding tensor.

    Returns:
        Tensor of shape (src_tokens.size(0), src_tokens.size(1), d_model) filled
        by `look_up_table`. The shape and the values are also printed.
    """
    embedding_matrix = state_dict["embed_tokens.weight"]

    # Zero-initialised buffer that the lookup fills position by position.
    batch_size, num_tokens = src_tokens.size(0), src_tokens.size(1)
    output_buffer = torch.zeros(batch_size, num_tokens, d_model)

    print("Source sentence embedding")
    embedded = look_up_table(src_tokens, embedding_matrix, output_buffer)

    print(embedded.shape)
    print("Source embeddings : \n")
    print(embedded)

    return embedded


# Embed the prompt tokens with the model's own embedding matrix.
input_embeddings = get_embedding_outputs(src_tokenized, state_dict, d_model = embed_dim)

Source sentence embedding
Word index: 1, Embedding: tensor([-0.0334, -0.0280, -0.0287,  0.0086, -0.0090, -0.0133,  0.0122, -0.0035])
Word index: 7251, Embedding: tensor([ 0.0080,  0.0184,  0.0083,  0.0189,  0.0256, -0.0133, -0.0215, -0.0009])
Word index: 920, Embedding: tensor([ 0.0525,  0.0237, -0.0210, -0.0215,  0.0180,  0.0015,  0.0290, -0.0205])
Word index: 526, Embedding: tensor([-0.0094, -0.0067, -0.0051, -0.0405,  0.0052, -0.0106, -0.0160,  0.0159])

torch.Size([1, 4, 8])
Source embeddings : 

tensor([[[-0.0334, -0.0280, -0.0287,  0.0086, -0.0090, -0.0133,  0.0122,
          -0.0035],
         [ 0.0080,  0.0184,  0.0083,  0.0189,  0.0256, -0.0133, -0.0215,
          -0.0009],
         [ 0.0525,  0.0237, -0.0210, -0.0215,  0.0180,  0.0015,  0.0290,
          -0.0205],
         [-0.0094, -0.0067, -0.0051, -0.0405,  0.0052, -0.0106, -0.0160,
           0.0159]]])


In [13]:
def apply_layernorm(hidden_states, wt, variance_epsilon = 1e-05):
    """Scale each feature vector by the reciprocal of its root mean square, then by `wt`.

    Despite the name, no mean is subtracted and no bias is added: the statistic
    used is the mean of the squared values over the last axis (an RMS norm).

    Args:
        hidden_states: input tensor; normalisation runs over its last axis.
        wt: per-feature scale multiplied onto the normalised values.
        variance_epsilon: constant added to the mean square before the reciprocal
            square root.

    Returns:
        The scaled tensor, cast back to the dtype `hidden_states` arrived with.
    """
    original_dtype = hidden_states.dtype

    # The statistics are computed in float32 regardless of the input dtype.
    x = hidden_states.to(torch.float32)
    mean_square = x.pow(2).mean(-1, keepdim=True)
    normalized = x * torch.rsqrt(mean_square + variance_epsilon)

    return (wt * normalized).to(original_dtype)

In [14]:
# Llama decoder layers are pre-norm: the skip connection carries the
# *un-normalised* layer input, and only the normalised copy is fed to
# self-attention. The residual therefore has to be the raw embeddings, not the
# output of the input layernorm.
residual = input_embeddings

hidden_state = apply_layernorm(input_embeddings, state_dict["layers.0.input_layernorm.weight"], variance_epsilon = 1e-05)

hidden_state


tensor([[[-1.6458, -1.3806, -1.4139,  0.4244, -0.4413, -0.6542,  0.5997,
          -0.1725],
         [ 0.4819,  1.1102,  0.4978,  1.1398,  1.5423, -0.8012, -1.2943,
          -0.0517],
         [ 1.9347,  0.8732, -0.7735, -0.7947,  0.6625,  0.0553,  1.0699,
          -0.7544],
         [-0.5286, -0.3752, -0.2874, -2.2779,  0.2938, -0.5967, -0.9011,
           0.8947]]])

In [15]:
# List the parameter names available for the manual computation.
state_dict.keys()

odict_keys(['embed_tokens.weight', 'layers.0.self_attn.q_proj.weight', 'layers.0.self_attn.k_proj.weight', 'layers.0.self_attn.v_proj.weight', 'layers.0.self_attn.o_proj.weight', 'layers.0.mlp.gate_proj.weight', 'layers.0.mlp.up_proj.weight', 'layers.0.mlp.down_proj.weight', 'layers.0.input_layernorm.weight', 'layers.0.post_attention_layernorm.weight', 'norm.weight'])

In [16]:
def get_qkv(hidden_state, Wq, Wk, Wv):
    """Project `hidden_state` with the query, key and value weight matrices.

    Each weight is applied as `hidden_state @ W.T`, i.e. a bias-free linear layer
    whose weight is stored as (out_features, in_features).

    Args:
        hidden_state: input tensor whose last axis matches the weights' second axis.
        Wq, Wk, Wv: 2-D projection weights for query, key and value.

    Returns:
        Tuple `(q, k, v)` of the three projected tensors, in that order.
    """
    return tuple(hidden_state @ weight.T for weight in (Wq, Wk, Wv))
    

In [17]:
# Projection weights of the first (and, for this config, only) decoder layer's self-attention.
Wq = state_dict["layers.0.self_attn.q_proj.weight"]
Wk = state_dict["layers.0.self_attn.k_proj.weight"]
Wv = state_dict["layers.0.self_attn.v_proj.weight"]
# Project the normalised hidden state into query / key / value.
query, key, value = get_qkv(hidden_state ,Wq, Wk, Wv)

In [18]:
# Inspect the flat projections (before they are split into heads).
query, key, value

(tensor([[[-0.0394, -0.0600, -0.0235, -0.1077, -0.0176, -0.0679,  0.0625,
            0.0028],
          [ 0.0181, -0.0323,  0.0318, -0.0450, -0.0237,  0.0436, -0.0671,
           -0.0137],
          [ 0.0511,  0.0496, -0.0387,  0.1762,  0.0113,  0.0426, -0.0325,
            0.0862],
          [ 0.0717, -0.0031,  0.0471,  0.0010, -0.0151,  0.0288, -0.1334,
           -0.0413]]]),
 tensor([[[ 0.0698, -0.0077, -0.0222, -0.0392, -0.0634,  0.0465,  0.0072,
           -0.0172],
          [-0.0269, -0.0502, -0.0530, -0.0392,  0.0087, -0.1056, -0.0026,
            0.1114],
          [ 0.0329, -0.0071,  0.0435,  0.0644,  0.0375, -0.0744,  0.0895,
            0.0708],
          [ 0.0050,  0.0169,  0.0753, -0.1054,  0.0291,  0.0531,  0.0078,
           -0.0332]]]),
 tensor([[[ 0.0317, -0.0136,  0.0647,  0.0628, -0.0717, -0.0558,  0.0196,
            0.0803],
          [ 0.0189, -0.0083, -0.0210,  0.0649,  0.0463,  0.0822, -0.0375,
           -0.0235],
          [ 0.0314,  0.0382, -0.1098, -0.141

In [19]:
def get_sin_cos(dim, seq_len, max_seq_len, base = 10000):
    """Build the rotary-embedding cosine and sine tables.

    Note the return order: cosine first, then sine, despite the function's name.

    Args:
        dim: rotary dimension; one inverse frequency is produced for every even
            index in `range(0, dim, 2)`.
        seq_len: number of leading positions to return.
        max_seq_len: number of positions the full table is computed for.
        base: base of the geometric frequency progression.

    Returns:
        `(cos, sin)`, each the first `seq_len` rows of a table with `max_seq_len`
        rows and twice as many columns as there are inverse frequencies.
    """
    # inv_freq[i] = 1 / base ** (2i / dim)
    even_indices = torch.arange(0, dim, 2, dtype=torch.int64).float()
    inv_freq = 1.0 / (base ** (even_indices / dim))

    # Angle for every (position, frequency) pair.
    positions = torch.arange(max_seq_len, dtype=torch.int64).type_as(inv_freq)
    angles = torch.outer(positions, inv_freq)

    # The angles are repeated along the last axis so that column j and column
    # j + half share a frequency, matching the split used by `rotate_half`.
    angle_table = torch.cat((angles, angles), dim=-1)

    cos_table = angle_table.cos()
    sin_table = angle_table.sin()
    return cos_table[:seq_len], sin_table[:seq_len]


def rotate_half(x):
    """Swap the two halves of the last axis and negate the half that moves forward.

    With the last axis split at `x.shape[-1] // 2` into `(front, back)`, the
    result is `(-back, front)` concatenated along that axis.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply the rotary position embedding to the query and key tensors.

    Args:
        q: query tensor.
        k: key tensor.
        cos: cosine table; gets an extra axis at `unsqueeze_dim` so that it
            broadcasts against `q` and `k`.
        sin: sine table, treated the same way as `cos`.
        unsqueeze_dim: position at which the broadcast axis is inserted.

    Returns:
        `(q_embed, k_embed)`, each computed as `t * cos + rotate_half(t) * sin`.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Same formula for queries and keys.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)

In [20]:
# Split the flat projections into heads:
#   (bsz, q_len, embed_dim) -> (bsz, num_heads, q_len, head_dim)
# The projections are recomputed from `hidden_state` rather than reshaping
# `query` / `key` / `value` from themselves, so re-running this cell is
# idempotent (unpacking the already-split 4-D tensors a second time would fail).
bsz, q_len, _ = hidden_state.shape

query, key, value = (
    proj.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
    for proj in get_qkv(hidden_state, Wq, Wk, Wv)
)

In [21]:
def get_Rope(query, key, head_dim, seq_len,  num_heads, max_seq_len = 2048, base = 10000):
    """Rotate per-head query and key tensors with rotary position embeddings.

    Args:
        query: query tensor laid out as (bsz, num_heads, seq_len, head_dim).
        key: key tensor with the same layout as `query`.
        head_dim: size of the last axis of `query` / `key`.
        seq_len: number of positions to rotate (positions 0 .. seq_len - 1).
        num_heads: unused; kept so existing calls keep working.
        max_seq_len: number of positions the sin/cos table is built for. The
            default matches `max_position_embeddings` of the config shown above.
        base: rotary frequency base. The default matches `rope_theta` of the
            config shown above.

    Returns:
        `(q_rotated, k_rotated)` with the same shapes as the inputs.

    Raises:
        ValueError: if `seq_len` exceeds `max_seq_len`, in which case the table
            would be too short for the sequence.
    """
    if seq_len > max_seq_len:
        raise ValueError(
            f"seq_len ({seq_len}) exceeds max_seq_len ({max_seq_len})"
        )

    cos, sin = get_sin_cos(dim = head_dim, seq_len = seq_len, max_seq_len = max_seq_len, base = base)

    # (seq_len, head_dim) -> (1, seq_len, head_dim); apply_rotary_pos_emb then
    # inserts the head axis at dim 1 so the tables broadcast over batch and heads.
    cos = cos.unsqueeze(0)
    sin = sin.unsqueeze(0)

    q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=1)

    return q_rotated, k_rotated
    

In [22]:
# Apply RoPE to the head-split queries and keys (values are not rotated).
query_rotated, key_rotated = get_Rope(query, key, head_dim, seq_len,  num_heads)

In [23]:
# Inspect the rotated queries and keys; position 0 is unchanged by the rotation.
query_rotated, key_rotated

(tensor([[[[-0.0394, -0.0600, -0.0235, -0.1077],
           [-0.0169, -0.0319,  0.0324, -0.0453],
           [ 0.0139,  0.0461,  0.0626,  0.1771],
           [-0.0776, -0.0031, -0.0365,  0.0009]],
 
          [[-0.0176, -0.0679,  0.0625,  0.0028],
           [ 0.0437,  0.0437, -0.0562, -0.0133],
           [ 0.0249,  0.0408,  0.0238,  0.0870],
           [ 0.0338,  0.0300,  0.1299, -0.0404]]]]),
 tensor([[[[ 0.0698, -0.0077, -0.0222, -0.0392],
           [ 0.0300, -0.0498, -0.0513, -0.0397],
           [-0.0532, -0.0084,  0.0118,  0.0642],
           [-0.0156,  0.0201, -0.0738, -0.1049]],
 
          [[-0.0634,  0.0465,  0.0072, -0.0172],
           [ 0.0069, -0.1067,  0.0059,  0.1103],
           [-0.0970, -0.0758, -0.0032,  0.0693],
           [-0.0299,  0.0541, -0.0036, -0.0316]]]]))

In [24]:
# Inspect the head-split values that attention will mix.
value

tensor([[[[ 0.0317, -0.0136,  0.0647,  0.0628],
          [ 0.0189, -0.0083, -0.0210,  0.0649],
          [ 0.0314,  0.0382, -0.1098, -0.1415],
          [ 0.0395,  0.1032,  0.0332, -0.0572]],

         [[-0.0717, -0.0558,  0.0196,  0.0803],
          [ 0.0463,  0.0822, -0.0375, -0.0235],
          [ 0.1071, -0.0132,  0.0508,  0.0340],
          [ 0.0280, -0.0223,  0.0267,  0.0051]]]])

In [25]:
import math

def self_attention_rope(query, key, value, attn_mask = None, scale = None, is_causal=False):
    """Scaled dot-product attention, written out step by step.

    Args:
        query: tensor of shape (..., L, head_dim).
        key: tensor of shape (..., S, head_dim).
        value: tensor of shape (..., S, value_dim).
        attn_mask: optional mask broadcastable to the (..., L, S) score tensor.
            A boolean mask keeps positions that are True; any other dtype is
            added to the scores. Must be None when `is_causal` is True.
        scale: factor applied to the raw scores; defaults to
            1 / sqrt(query.size(-1)).
        is_causal: when True, position i only attends to positions <= i.

    Returns:
        Attention output of shape (..., L, value_dim). It is also printed.

    Raises:
        AssertionError: if `is_causal` is combined with `attn_mask`, or if the
            attention weights of a row do not sum to 1 (within 1e-6).
    """
    L, S = query.size(-2), key.size(-2)

    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale

    # Additive bias on the scores, created alongside the query so the function
    # also works when the inputs do not live on the default device.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Turn the keep/drop mask into 0 / -inf at the mask's own shape.
            mask_bias = torch.zeros(attn_mask.shape, dtype=query.dtype, device=attn_mask.device)
            mask_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            mask_bias = attn_mask

        # Out-of-place add: a batched mask such as (bsz, 1, L, S) broadcasts
        # here, whereas an in-place update of the (L, S) bias would raise.
        attn_bias = attn_bias + mask_bias

    # (bsz, num_heads, tgt_len, head_dim) @ (bsz, num_heads, head_dim, tgt_len) -> (bsz, num_heads, tgt_len, tgt_len)
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias

    # (bsz, num_heads, tgt_len, tgt_len)
    attn_weight = torch.softmax(attn_weight, dim=-1)

    # Sanity check: every row of attention weights is a probability distribution.
    sum_last_dim = attn_weight.sum(dim=-1)
    tolerance = 1e-6
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"

    # (bsz, num_heads, tgt_len, tgt_len) @ (bsz, num_heads, tgt_len, head_dim) -> (bsz, num_heads, tgt_len, head_dim)
    attn_output = attn_weight @ value

    print("ATTEN OUTPUT = ", attn_output)

    return attn_output


In [26]:
# Causal attention over the rotated queries/keys; no explicit mask is passed,
# `is_causal=True` makes the function build the lower-triangular mask itself.
self_attn_op = self_attention_rope(query_rotated, key_rotated, value, attn_mask = None, is_causal=True)

ATTEN OUTPUT =  tensor([[[[ 0.0317, -0.0136,  0.0647,  0.0628],
          [ 0.0253, -0.0110,  0.0219,  0.0638],
          [ 0.0274,  0.0055, -0.0223, -0.0051],
          [ 0.0304,  0.0299, -0.0083, -0.0179]],

         [[-0.0717, -0.0558,  0.0196,  0.0803],
          [-0.0128,  0.0131, -0.0089,  0.0285],
          [ 0.0273,  0.0045,  0.0109,  0.0302],
          [ 0.0273, -0.0023,  0.0149,  0.0240]]]])


In [27]:
# Expected layout: (bsz, num_heads, seq_len, head_dim).
self_attn_op.shape

torch.Size([1, 2, 4, 4])

In [28]:
def out_proj_self_attn(self_attn_op, W, embed_dim):
    """Merge the attention heads and apply the output projection.

    Args:
        self_attn_op: attention output laid out as (bsz, num_heads, q_len, head_dim).
        W: output-projection weight, applied as `x @ W.T`.
        embed_dim: merged feature size; must equal num_heads * head_dim.

    Returns:
        Tensor of shape (bsz, q_len, W.shape[0]).
    """
    # Read the batch and sequence sizes from the input itself rather than from
    # notebook-level globals, so the function works for any input and does not
    # depend on which cells happened to run before it.
    bsz, _, q_len, _ = self_attn_op.shape

    # (bsz, num_heads, q_len, head_dim) -> (bsz, q_len, num_heads, head_dim) -> (bsz, q_len, embed_dim)
    merged = self_attn_op.transpose(1, 2).contiguous().reshape(bsz, q_len, embed_dim)

    return merged @ W.T

In [29]:
# Output-projection weight of the first decoder layer's self-attention.
Wo = state_dict["layers.0.self_attn.o_proj.weight"]
print(Wo.shape)

# Merge the heads and project back to the model width.
sa_output = out_proj_self_attn(self_attn_op, Wo, embed_dim)

torch.Size([8, 8])


In [30]:
# Inspect the projected self-attention output, one row per token.
sa_output

tensor([[[ 5.3733e-04, -1.7497e-05,  3.1351e-04,  3.3713e-03, -3.0710e-03,
           3.1266e-03,  3.8834e-03, -1.0336e-03],
         [ 1.8780e-04,  9.9296e-04,  2.1365e-03,  1.2946e-04, -2.4825e-03,
           9.5584e-04,  1.0146e-03, -5.4013e-04],
         [ 2.2841e-03,  5.6324e-05,  1.6864e-03, -6.4167e-04,  4.1008e-04,
           2.3058e-03,  1.9699e-03, -9.3434e-04],
         [ 2.0368e-03, -1.0943e-03,  7.3408e-04, -1.7564e-04,  1.0480e-03,
           2.5588e-03,  2.1456e-03, -1.2697e-04]]])