In [73]:
from transformers.trainer_utils import set_seed

SEED = 6
set_seed(SEED)

In [74]:
from transformers import AutoConfig, LlamaConfig


# Llama_config = LlamaConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_hidden_layers = 1, use_cache = False, hidden_size = 4, num_attention_heads = 1, 
#                                            output_hidden_states=True,  num_key_value_heads = 1, past_key_values = True)


Llama_config = LlamaConfig.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", num_hidden_layers = 1, use_cache = False, hidden_size = 8, num_attention_heads = 4, 
                                           output_hidden_states=True,  num_key_value_heads = 2, past_key_values = True)


Llama_config

LlamaConfig {
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 8,
  "initializer_range": 0.02,
  "intermediate_size": 5632,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 4,
  "num_hidden_layers": 1,
  "num_key_value_heads": 2,
  "output_hidden_states": true,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": false,
  "vocab_size": 32000
}

In [75]:
from transformers import AutoModel

tinyllama = AutoModel.from_config(Llama_config)

In [76]:
from transformers import LlamaTokenizer

src_sent = "hi how are"

llama_tokenizer = LlamaTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

In [77]:
tokenized_src_dict = llama_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[   1, 7251,  920,  526]]), 'attention_mask': tensor([[1, 1, 1, 1]])}

In [78]:
src_tokenized = tokenized_src_dict["input_ids"]
src_tokenized

tensor([[   1, 7251,  920,  526]])

In [79]:
llama_tokenizer.decode(*src_tokenized)

'<s> hi how are'

## Verifying the RoPE of Llama2 

### **GROUPED-QUERY ATTENTION**

In [80]:
# head_dim = 4
# num_heads = 1
# seq_len = 4

embed_dim = 8

num_heads = 4

n_heads_q = num_heads

n_kv_heads = 2

seq_len = 4

head_dim = embed_dim // num_heads

n_rep = n_heads_q//n_kv_heads
print("n_rep = ", n_rep)

n_rep =  2


In [81]:
state_dict = tinyllama.state_dict()

In [82]:
import torch

def look_up_table(sentence, vocab_embeds, embedding):

    for i in range(sentence.size(0)):
        for j in range(sentence.size(1)):
            
            # Get the index for the current word token index in the sequence
            word_index = sentence[i, j].item()

            if word_index < 0 or word_index >= vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            embedding[i, j, :] = vocab_embeds[word_index, :]

            print(f"Word index: {word_index}, Embedding: {vocab_embeds[word_index, :]}")
    print()

    return embedding

In [83]:
def get_embedding_outputs(src_tokens, state_dict, d_model):

    src_vocab_embeds = state_dict["embed_tokens.weight"]

    src_embedding = torch.zeros(src_tokens.size(0), src_tokens.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_tokens, src_vocab_embeds, src_embedding)
    print(src_embedding.shape)


    print("Source embeddings : \n")
    print(src_embedding)

    return src_embedding


input_embeddings = get_embedding_outputs(src_tokenized, state_dict, d_model = embed_dim)

Source sentence embedding
Word index: 1, Embedding: tensor([-0.0258,  0.0210, -0.0483, -0.0079, -0.0164, -0.0140, -0.0164, -0.0006])
Word index: 7251, Embedding: tensor([ 0.0074, -0.0400,  0.0304, -0.0152,  0.0153,  0.0289,  0.0037,  0.0250])
Word index: 920, Embedding: tensor([ 0.0291,  0.0036, -0.0103, -0.0144,  0.0057, -0.0030, -0.0159, -0.0409])
Word index: 526, Embedding: tensor([ 0.0214, -0.0114, -0.0003,  0.0414,  0.0012,  0.0080, -0.0227,  0.0104])

torch.Size([1, 4, 8])
Source embeddings : 

tensor([[[-0.0258,  0.0210, -0.0483, -0.0079, -0.0164, -0.0140, -0.0164,
          -0.0006],
         [ 0.0074, -0.0400,  0.0304, -0.0152,  0.0153,  0.0289,  0.0037,
           0.0250],
         [ 0.0291,  0.0036, -0.0103, -0.0144,  0.0057, -0.0030, -0.0159,
          -0.0409],
         [ 0.0214, -0.0114, -0.0003,  0.0414,  0.0012,  0.0080, -0.0227,
           0.0104]]])


In [84]:
residual = input_embeddings

In [85]:
def apply_layernorm(hidden_states, wt, variance_epsilon = 1e-05):

    dtype =  hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    op = wt * hidden_states
    return op.to(dtype)

In [86]:
hidden_state = apply_layernorm(input_embeddings, state_dict["layers.0.input_layernorm.weight"], variance_epsilon = 1e-05)

hidden_state


tensor([[[-1.1118,  0.9048, -2.0794, -0.3405, -0.7058, -0.6022, -0.7040,
          -0.0255],
         [ 0.3099, -1.6691,  1.2685, -0.6328,  0.6387,  1.2049,  0.1544,
           1.0419],
         [ 1.4492,  0.1812, -0.5150, -0.7174,  0.2862, -0.1514, -0.7926,
          -2.0385],
         [ 1.0946, -0.5810, -0.0146,  2.1127,  0.0634,  0.4081, -1.1568,
           0.5321]]])

In [87]:
state_dict.keys()

odict_keys(['embed_tokens.weight', 'layers.0.self_attn.q_proj.weight', 'layers.0.self_attn.k_proj.weight', 'layers.0.self_attn.v_proj.weight', 'layers.0.self_attn.o_proj.weight', 'layers.0.mlp.gate_proj.weight', 'layers.0.mlp.up_proj.weight', 'layers.0.mlp.down_proj.weight', 'layers.0.input_layernorm.weight', 'layers.0.post_attention_layernorm.weight', 'norm.weight'])

In [88]:
def get_qkv(hidden_state ,Wq, Wk, Wv):


    q_matmul = hidden_state@Wq.T
    k_matmul = hidden_state@Wk.T
    v_matmul = hidden_state@Wv.T

    return q_matmul, k_matmul, v_matmul
    

In [89]:
Wq = state_dict["layers.0.self_attn.q_proj.weight"]
Wk = state_dict["layers.0.self_attn.k_proj.weight"]
Wv = state_dict["layers.0.self_attn.v_proj.weight"]
query, key, value = get_qkv(hidden_state ,Wq, Wk, Wv)

In [90]:
query, key, value

(tensor([[[ 0.0507,  0.0144,  0.0718, -0.0079, -0.0578,  0.0141, -0.1129,
            0.0667],
          [-0.0108,  0.0590, -0.0587,  0.0913,  0.0326, -0.0691,  0.0018,
           -0.0266],
          [-0.0745,  0.0044,  0.0250, -0.1372,  0.0829,  0.0847,  0.0965,
           -0.1339],
          [-0.0155, -0.0212, -0.0371,  0.0443, -0.0387, -0.0826,  0.0557,
            0.0384]]]),
 tensor([[[-0.0798, -0.0299, -0.0730, -0.0716],
          [ 0.0677,  0.0232,  0.0885,  0.0634],
          [-0.0026, -0.0251, -0.0366,  0.0742],
          [-0.0431, -0.0219,  0.0021, -0.0356]]]),
 tensor([[[ 0.0016, -0.0609, -0.0340, -0.0412],
          [ 0.0243,  0.0421, -0.0112, -0.0031],
          [-0.0515,  0.0227, -0.0062,  0.0636],
          [ 0.0180,  0.0050,  0.0095, -0.0304]]]))

In [91]:
def get_sin_cos(dim, seq_len, max_seq_len, base = 10000):

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

    t = torch.arange(max_seq_len, dtype=torch.int64).type_as(inv_freq)

    freqs = torch.outer(t, inv_freq)
    
    # Uses a different permutation in order to obtain the same calculation
    emb = torch.cat((freqs, freqs), dim=-1) 

    return  emb.cos()[:seq_len], emb.sin()[:seq_len]


def rotate_half(x):

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]

    # print(x1.shape, x2.shape)
    # x1 = x[ : , : x.shape[-1] // 2]
    # x2 = x[ : , x.shape[-1] // 2 :]

    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # print("HALF ROT SHAPES = ")
    # print((rotate_half(q).shape,   sin.shape))
    # print((rotate_half(k).shape , sin.shape))

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

In [92]:
query.shape, key.shape, value.shape

(torch.Size([1, 4, 8]), torch.Size([1, 4, 4]), torch.Size([1, 4, 4]))

In [93]:
bsz, q_len, _ = query.shape

query = query.view(bsz, q_len, n_heads_q, head_dim).transpose(1, 2)
key = key.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
value = value.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

In [94]:
query.shape, key.shape, value.shape

(torch.Size([1, 4, 4, 2]), torch.Size([1, 2, 4, 2]), torch.Size([1, 2, 4, 2]))

In [95]:
def rotate_half(x):

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]

    # print(x1.shape, x2.shape)
    # x1 = x[ : , : x.shape[-1] // 2]
    # x2 = x[ : , x.shape[-1] // 2 :]

    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    print(cos.shape, sin.shape)
    print(q.shape, k.shape)

    # print("HALF ROT SHAPES = ")
    # print((rotate_half(q).shape,   sin.shape))
    # print((rotate_half(k).shape , sin.shape))
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

def get_Rope(query, key, head_dim, seq_len,  num_heads):

    cos, sin = get_sin_cos(dim = head_dim, seq_len = seq_len, max_seq_len = 2048, base = 10000)

    cos = cos.unsqueeze(0)
    sin = sin.unsqueeze(0)

    print(cos.shape, sin.shape)

    q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=1)

    return q_rotated, k_rotated
    

In [96]:
query_rotated, key_rotated = get_Rope(query, key, head_dim, seq_len,  num_heads)

torch.Size([1, 4, 2]) torch.Size([1, 4, 2])
torch.Size([1, 1, 4, 2]) torch.Size([1, 1, 4, 2])
torch.Size([1, 4, 4, 2]) torch.Size([1, 2, 4, 2])


In [97]:
query_rotated, key_rotated

(tensor([[[[ 5.0677e-02,  1.4430e-02],
           [-5.5456e-02,  2.2768e-02],
           [ 2.7001e-02, -6.9541e-02],
           [ 1.8380e-02,  1.8795e-02]],
 
          [[ 7.1790e-02, -7.9177e-03],
           [-1.0855e-01, -2.3097e-06],
           [ 1.1433e-01,  7.9771e-02],
           [ 3.0465e-02, -4.9116e-02]],
 
          [[-5.7788e-02,  1.4150e-02],
           [ 7.5767e-02, -9.9350e-03],
           [-1.1151e-01,  4.0088e-02],
           [ 5.0002e-02,  7.6326e-02]],
 
          [[-1.1291e-01,  6.6661e-02],
           [ 2.3330e-02, -1.2850e-02],
           [ 8.1624e-02,  1.4343e-01],
           [-6.0592e-02, -3.0113e-02]]]]),
 tensor([[[[-0.0798, -0.0299],
           [ 0.0171,  0.0695],
           [ 0.0239,  0.0081],
           [ 0.0458,  0.0156]],
 
          [[-0.0730, -0.0716],
           [-0.0055,  0.1087],
           [-0.0523, -0.0641],
           [ 0.0030,  0.0355]]]]))

In [98]:
value

tensor([[[[ 0.0016, -0.0609],
          [ 0.0243,  0.0421],
          [-0.0515,  0.0227],
          [ 0.0180,  0.0050]],

         [[-0.0340, -0.0412],
          [-0.0112, -0.0031],
          [-0.0062,  0.0636],
          [ 0.0095, -0.0304]]]])

In [99]:
def repeat_kv(x, n_rep):

    bsz, num_key_value_heads, seq_len, head_dim = x.shape

    if n_rep == 1:
        return x
    
    x = x[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
    return x.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)


In [100]:
key_rotated = repeat_kv(key_rotated, n_rep)
value = repeat_kv(value, n_rep)

In [101]:
# value

In [102]:
query.shape, key_rotated.shape, value.shape

(torch.Size([1, 4, 4, 2]), torch.Size([1, 4, 4, 2]), torch.Size([1, 4, 4, 2]))

In [103]:
import math

def self_attention_rope(query, key, value, attn_mask = None, scale = None, is_causal=False):

    L, S = query.size(-2), key.size(-2)

    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)

    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask



    # (bsz, num_heads, tgt_len, head_dim) @ (bsz, num_heads, head_dim, tgt_len) -> (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias

    # (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = torch.softmax(attn_weight, dim=-1)

    print("attn weight = ", attn_weight)


    sum_last_dim = attn_weight.sum(dim=-1)
    tolerance = 1e-6  
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"


    # # (bsz, num_heads, tgt_len, tgt_len) @ (bsz, num_heads, tgt_len, head_dim) -> (bsz, num_heads, tgt_len, head_dim) 
    attn_output = attn_weight @ value

    print("ATTEN OUTPUT = ", attn_output)

    return attn_output


In [104]:
self_attn_op = self_attention_rope(query_rotated, key_rotated, value, attn_mask = None, is_causal=True)

attn weight =  tensor([[[[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5005, 0.4995, 0.0000, 0.0000],
          [0.3337, 0.3326, 0.3337, 0.0000],
          [0.2496, 0.2502, 0.2500, 0.2501]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5019, 0.4981, 0.0000, 0.0000],
          [0.3307, 0.3351, 0.3342, 0.0000],
          [0.2500, 0.2496, 0.2502, 0.2502]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.4994, 0.5006, 0.0000, 0.0000],
          [0.3335, 0.3334, 0.3330, 0.0000],
          [0.2486, 0.2517, 0.2489, 0.2508]],

         [[1.0000, 0.0000, 0.0000, 0.0000],
          [0.5001, 0.4999, 0.0000, 0.0000],
          [0.3306, 0.3381, 0.3313, 0.0000],
          [0.2508, 0.2491, 0.2506, 0.2494]]]])
ATTEN OUTPUT =  tensor([[[[ 0.0016, -0.0609],
          [ 0.0129, -0.0094],
          [-0.0086,  0.0013],
          [-0.0019,  0.0023]],

         [[ 0.0016, -0.0609],
          [ 0.0129, -0.0096],
          [-0.0085,  0.0016],
          [-0.0019,  0.0022]],

         [[-

In [105]:
self_attn_op.shape

torch.Size([1, 4, 4, 2])

In [106]:
def out_proj_self_attn(self_attn_op, W, embed_dim):

    self_attn_op = self_attn_op.transpose(1, 2).contiguous()
    self_attn_op = self_attn_op.reshape(bsz, q_len, embed_dim)
    return self_attn_op@W.T

In [107]:
Wo = state_dict["layers.0.self_attn.o_proj.weight"]
print(Wo.shape)

sa_output = out_proj_self_attn(self_attn_op, Wo, embed_dim)

torch.Size([8, 8])


In [108]:
sa_output

tensor([[[ 7.2769e-03, -2.6772e-03,  5.2673e-03, -8.7639e-05,  5.1127e-05,
           2.1528e-03, -5.8406e-03, -1.3345e-03],
         [ 3.0233e-03, -1.5681e-03,  2.1389e-03,  1.0173e-03,  4.9285e-04,
           6.7958e-05, -3.4745e-03, -3.2432e-04],
         [-1.2038e-03,  3.6274e-04,  3.0120e-04, -1.1266e-04, -1.5346e-04,
           3.7689e-04, -2.5342e-04, -1.3245e-03],
         [ 2.0138e-05, -1.6477e-04,  4.0039e-04,  1.5589e-04,  6.1682e-05,
           3.0537e-05, -6.1216e-04, -4.5421e-04]]])

In [109]:
hidden_states = residual + sa_output

hidden_states


tensor([[[-0.0186,  0.0183, -0.0430, -0.0080, -0.0163, -0.0118, -0.0222,
          -0.0019],
         [ 0.0104, -0.0415,  0.0325, -0.0141,  0.0158,  0.0289,  0.0002,
           0.0246],
         [ 0.0279,  0.0040, -0.0100, -0.0145,  0.0056, -0.0027, -0.0162,
          -0.0422],
         [ 0.0215, -0.0115,  0.0001,  0.0415,  0.0013,  0.0080, -0.0233,
           0.0100]]])

In [110]:
residual = hidden_states

In [111]:
hidden_state = apply_layernorm(hidden_states, state_dict["layers.0.post_attention_layernorm.weight"], variance_epsilon = 1e-05)

hidden_state

tensor([[[-0.8769,  0.8670, -2.0344, -0.3781, -0.7726, -0.5595, -1.0491,
          -0.0911],
         [ 0.4236, -1.6847,  1.3187, -0.5734,  0.6403,  1.1730,  0.0090,
           0.9987],
         [ 1.3795,  0.1979, -0.4964, -0.7179,  0.2766, -0.1317, -0.7995,
          -2.0897],
         [ 1.0891, -0.5859,  0.0058,  2.1080,  0.0661,  0.4072, -1.1810,
           0.5059]]])

In [112]:
# Only when 'pretraining_tp == 1'
import torch.nn as nn


def llama_mlp(x, W_down_proj,W_up_proj, W_gate_proj):
    
    up_proj =  x@W_up_proj.T
    gate_proj =  x@W_gate_proj.T

    print("UP PROJ = ", up_proj)
    print()

    print("GATE PROJ = ", gate_proj)
    print()

    temp_proj = up_proj * gate_proj

    silu = nn.SiLU()

    temp_proj = silu(temp_proj)

    print("ACT = ",temp_proj)

    down_proj = temp_proj@W_down_proj.T

    return down_proj

In [113]:
W_down_proj = state_dict["layers.0.mlp.down_proj.weight"]

W_up_proj = state_dict["layers.0.mlp.up_proj.weight"]

W_gate_proj = state_dict["layers.0.mlp.gate_proj.weight"]

In [114]:
hidden_state = llama_mlp(hidden_state, W_down_proj,W_up_proj, W_gate_proj)
hidden_state

UP PROJ =  tensor([[[-0.0796, -0.0497, -0.1191,  ..., -0.0259,  0.0210, -0.0127],
         [ 0.0585,  0.0547,  0.0888,  ...,  0.0021, -0.0128,  0.0241],
         [-0.0024, -0.0741, -0.0656,  ..., -0.1007, -0.0907, -0.0512],
         [ 0.0014,  0.0548,  0.0600,  ..., -0.0253, -0.0084,  0.0511]]])

GATE PROJ =  tensor([[[ 0.0527,  0.0807,  0.1029,  ..., -0.0380, -0.0192,  0.0034],
         [-0.0608, -0.0465, -0.0959,  ..., -0.0405, -0.0408,  0.0139],
         [ 0.0068,  0.0973, -0.0276,  ...,  0.0418,  0.0148,  0.0371],
         [-0.0763, -0.0806, -0.0333,  ...,  0.0304,  0.0270,  0.0448]]])

ACT =  tensor([[[-2.0931e-03, -2.0009e-03, -6.0904e-03,  ...,  4.9118e-04,
          -2.0208e-04, -2.1336e-05],
         [-1.7750e-03, -1.2702e-03, -4.2398e-03,  ..., -4.3319e-05,
           2.6042e-04,  1.6809e-04],
         [-8.1050e-06, -3.5940e-03,  9.0573e-04,  ..., -2.1021e-03,
          -6.6901e-04, -9.4906e-04],
         [-5.2391e-05, -2.2012e-03, -9.9823e-04,  ..., -3.8413e-04,
          -1

tensor([[[-8.4654e-04, -9.3272e-04, -7.8715e-04,  1.3893e-03, -3.4011e-04,
           1.0799e-03,  4.2394e-03,  1.2253e-03],
         [ 2.5295e-03, -1.2564e-03, -3.1189e-03, -1.5183e-03, -1.7109e-04,
           3.5854e-03,  2.0371e-03,  4.6839e-05],
         [-5.8827e-04, -4.1202e-04, -4.8891e-03,  1.3620e-03, -1.8044e-03,
           4.2751e-03,  1.1714e-03, -5.6578e-04],
         [-3.1874e-04,  8.3021e-04,  1.9767e-03,  1.2665e-03,  1.4345e-03,
          -1.0643e-03,  8.1463e-04, -1.9026e-03]]])

In [115]:
hidden_state = hidden_state + residual

In [116]:
hidden_state

tensor([[[-0.0194,  0.0174, -0.0438, -0.0066, -0.0167, -0.0108, -0.0180,
          -0.0007],
         [ 0.0130, -0.0428,  0.0294, -0.0157,  0.0156,  0.0325,  0.0023,
           0.0247],
         [ 0.0273,  0.0036, -0.0149, -0.0131,  0.0038,  0.0016, -0.0150,
          -0.0428],
         [ 0.0211, -0.0107,  0.0021,  0.0428,  0.0027,  0.0070, -0.0225,
           0.0081]]])