In [1]:
from transformers.trainer_utils import set_seed
import torch
SEED = 6
set_seed(SEED)

In [2]:
from transformers import AutoConfig, MistralConfig


mistral_config = MistralConfig.from_pretrained("openaccess-ai-collective/tiny-mistral", num_hidden_layers = 1, use_cache = False, hidden_size = 8, num_attention_heads = 4, 
                                           output_hidden_states=True,  num_key_value_heads = 2, past_key_values = True, intermediate_size = 8, sliding_window = 3, dropout_p = 0)


mistral_config

MistralConfig {
  "_name_or_path": "./tiny-mistral",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "dropout_p": 0,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 8,
  "initializer_range": 0.02,
  "intermediate_size": 8,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 4,
  "num_hidden_layers": 1,
  "num_key_value_heads": 2,
  "output_hidden_states": true,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 3,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": false,
  "vocab_size": 32000
}

In [3]:
from transformers import AutoModel

tinymistral = AutoModel.from_config(mistral_config)

In [4]:
from transformers import AutoTokenizer

src_sent = "hi how are you doing"

mistal_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

In [5]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [6]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [7]:
src_tokenized = tokenized_src_dict["input_ids"]
src_tokenized

tensor([[    1, 12014,   910,   460,   368,  2548]])

In [8]:
mistal_tokenizer.decode(*src_tokenized)

'<s> hi how are you doing'

In [9]:
seq_length = src_tokenized.shape[1]
sliding_window_len = 3

sliding_window_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

# print(sliding_window_mask)

for i in range(sliding_window_mask.shape[0]-1, -1, -1):

    li = i - sliding_window_len + 1

    if li > 0:

        sliding_window_mask[i][0:li] = True

sliding_window_mask

tensor([[False,  True,  True,  True,  True,  True],
        [False, False,  True,  True,  True,  True],
        [False, False, False,  True,  True,  True],
        [ True, False, False, False,  True,  True],
        [ True,  True, False, False, False,  True],
        [ True,  True,  True, False, False, False]])

In [10]:
sliding_window_mask = sliding_window_mask.unsqueeze(0)
sliding_window_mask = sliding_window_mask.unsqueeze(0)
sliding_window_mask.shape, sliding_window_mask

(torch.Size([1, 1, 6, 6]),
 tensor([[[[False,  True,  True,  True,  True,  True],
           [False, False,  True,  True,  True,  True],
           [False, False, False,  True,  True,  True],
           [ True, False, False, False,  True,  True],
           [ True,  True, False, False, False,  True],
           [ True,  True,  True, False, False, False]]]]))

In [11]:
from pprint import pprint

tokenized_src_dict["attention_mask"] = sliding_window_mask

pprint(tokenized_src_dict)


{'attention_mask': tensor([[[[False,  True,  True,  True,  True,  True],
          [False, False,  True,  True,  True,  True],
          [False, False, False,  True,  True,  True],
          [ True, False, False, False,  True,  True],
          [ True,  True, False, False, False,  True],
          [ True,  True,  True, False, False, False]]]]),
 'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]])}


In [12]:
output = tinymistral(**tokenized_src_dict)

sdpa

###########################################################################
LLAMA DECODER FWD START

Attention mask =  tensor([[[[-3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
            0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38]]]])

Input (hidden states) =  tensor([[[-0.0074, -0.0254,  0.0067,  0.0112,  0.0315,  0.0374,  0.0165,
          -0.0181],
         [-0.0131,  0.0218,  0.0225,  0.0525,  0.0102,  0.0363,  0.0032,
          -0.0137],
         [-0.0309,  0.0209, -0.0214,  0.0265, -0.0147,

## **SLIDING WINDOW ATTENTION** and **GROUPED-QUERY ATTENTION**

In [13]:
mistral_config.num_key_value_heads

2

In [14]:
# head_dim = 4
# num_heads = 1
# seq_len = 4

embed_dim = 8

num_heads = 4

n_heads_q = num_heads

n_kv_heads = 2

seq_len = 6

head_dim = embed_dim // num_heads

n_rep = n_heads_q//n_kv_heads
print("n_rep = ", n_rep)

n_rep =  2


In [15]:
state_dict = tinymistral.state_dict()

In [16]:
# state_dict

In [17]:
import torch

def look_up_table(sentence, vocab_embeds, embedding):

    for i in range(sentence.size(0)):
        for j in range(sentence.size(1)):
            
            # Get the index for the current word token index in the sequence
            word_index = sentence[i, j].item()

            if word_index < 0 or word_index >= vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            embedding[i, j, :] = vocab_embeds[word_index, :]

            print(f"Word index: {word_index}, Embedding: {vocab_embeds[word_index, :]}")
    print()

    return embedding

In [18]:
def get_embedding_outputs(src_tokens, state_dict, d_model):

    src_vocab_embeds = state_dict["embed_tokens.weight"]

    src_embedding = torch.zeros(src_tokens.size(0), src_tokens.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_tokens, src_vocab_embeds, src_embedding)
    print(src_embedding.shape)


    print("Source embeddings : \n")
    print(src_embedding)

    return src_embedding


input_embeddings = get_embedding_outputs(src_tokenized, state_dict, d_model = embed_dim)

Source sentence embedding
Word index: 1, Embedding: tensor([-0.0074, -0.0254,  0.0067,  0.0112,  0.0315,  0.0374,  0.0165, -0.0181])
Word index: 12014, Embedding: tensor([-0.0131,  0.0218,  0.0225,  0.0525,  0.0102,  0.0363,  0.0032, -0.0137])
Word index: 910, Embedding: tensor([-0.0309,  0.0209, -0.0214,  0.0265, -0.0147, -0.0243,  0.0350,  0.0250])
Word index: 460, Embedding: tensor([-0.0138, -0.0130,  0.0139,  0.0085, -0.0049,  0.0009, -0.0071,  0.0084])
Word index: 368, Embedding: tensor([ 0.0376, -0.0126,  0.0067, -0.0009,  0.0185, -0.0372, -0.0128,  0.0204])
Word index: 2548, Embedding: tensor([-0.0160,  0.0149,  0.0070, -0.0099, -0.0118, -0.0123,  0.0177,  0.0366])

torch.Size([1, 6, 8])
Source embeddings : 

tensor([[[-0.0074, -0.0254,  0.0067,  0.0112,  0.0315,  0.0374,  0.0165,
          -0.0181],
         [-0.0131,  0.0218,  0.0225,  0.0525,  0.0102,  0.0363,  0.0032,
          -0.0137],
         [-0.0309,  0.0209, -0.0214,  0.0265, -0.0147, -0.0243,  0.0350,
           0.02

In [19]:
residual = input_embeddings

In [20]:
def apply_layernorm(hidden_states, wt, variance_epsilon = 1e-05):

    dtype =  hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    op = wt * hidden_states
    return op.to(dtype)

In [21]:
hidden_state = apply_layernorm(input_embeddings, state_dict["layers.0.input_layernorm.weight"], variance_epsilon = 1e-05)

hidden_state


tensor([[[-0.3321, -1.1450,  0.3039,  0.5057,  1.4174,  1.6856,  0.7422,
          -0.8167],
         [-0.4934,  0.8224,  0.8510,  1.9829,  0.3837,  1.3713,  0.1228,
          -0.5167],
         [-1.2029,  0.8125, -0.8324,  1.0307, -0.5720, -0.9447,  1.3599,
           0.9734],
         [-1.3384, -1.2590,  1.3482,  0.8278, -0.4802,  0.0843, -0.6869,
           0.8108],
         [ 1.6819, -0.5647,  0.2996, -0.0414,  0.8280, -1.6612, -0.5732,
           0.9095],
         [-0.8816,  0.8184,  0.3826, -0.5416, -0.6470, -0.6772,  0.9711,
           2.0128]]])

In [22]:
state_dict.keys()

odict_keys(['embed_tokens.weight', 'layers.0.self_attn.q_proj.weight', 'layers.0.self_attn.k_proj.weight', 'layers.0.self_attn.v_proj.weight', 'layers.0.self_attn.o_proj.weight', 'layers.0.mlp.gate_proj.weight', 'layers.0.mlp.up_proj.weight', 'layers.0.mlp.down_proj.weight', 'layers.0.input_layernorm.weight', 'layers.0.post_attention_layernorm.weight', 'norm.weight'])

In [23]:
def get_qkv(hidden_state ,Wq, Wk, Wv):


    q_matmul = hidden_state@Wq.T
    k_matmul = hidden_state@Wk.T
    v_matmul = hidden_state@Wv.T

    return q_matmul, k_matmul, v_matmul
    

In [24]:
Wq = state_dict["layers.0.self_attn.q_proj.weight"]
Wk = state_dict["layers.0.self_attn.k_proj.weight"]
Wv = state_dict["layers.0.self_attn.v_proj.weight"]
query, key, value = get_qkv(hidden_state ,Wq, Wk, Wv)

In [25]:
query, key, value

(tensor([[[-0.0113, -0.0217, -0.0369, -0.0231,  0.0018,  0.0096, -0.0120,
            0.0402],
          [-0.0510, -0.0522, -0.0510, -0.0357, -0.0565,  0.0447,  0.0820,
           -0.0111],
          [-0.0135,  0.0034,  0.0354, -0.0907, -0.0300,  0.0246,  0.0032,
           -0.0650],
          [ 0.0072,  0.0276, -0.0701,  0.0316,  0.0627,  0.1072, -0.0125,
            0.0886],
          [ 0.0131,  0.0458,  0.0538,  0.0604, -0.0282,  0.0249, -0.0474,
           -0.0304],
          [-0.0426,  0.0593,  0.0134, -0.0548,  0.0401,  0.0886,  0.0021,
           -0.0048]]]),
 tensor([[[-0.0682,  0.0924, -0.0166,  0.0247],
          [-0.0963, -0.0072, -0.0289,  0.0196],
          [-0.0226, -0.1110, -0.0446,  0.0053],
          [ 0.0608, -0.0750, -0.0111, -0.0111],
          [ 0.0157,  0.0671,  0.0796, -0.0588],
          [ 0.0609, -0.1770, -0.0375,  0.0568]]]),
 tensor([[[-0.0372, -0.0752, -0.0424,  0.0104],
          [-0.0334, -0.1026,  0.0050, -0.0126],
          [-0.0387,  0.0184,  0.0497, -0

In [26]:
def get_sin_cos(dim, seq_len, max_seq_len, base = 10000):

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

    t = torch.arange(max_seq_len, dtype=torch.int64).type_as(inv_freq)

    freqs = torch.outer(t, inv_freq)
    
    # Uses a different permutation in order to obtain the same calculation
    emb = torch.cat((freqs, freqs), dim=-1)    

    return  emb.cos()[:seq_len], emb.sin()[:seq_len]


def rotate_half(x):

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]

    # print(x1.shape, x2.shape)
    # x1 = x[ : , : x.shape[-1] // 2]
    # x2 = x[ : , x.shape[-1] // 2 :]

    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # print("HALF ROT SHAPES = ")
    # print((rotate_half(q).shape,   sin.shape))
    # print((rotate_half(k).shape , sin.shape))

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

In [27]:
query.shape, key.shape, value.shape

(torch.Size([1, 6, 8]), torch.Size([1, 6, 4]), torch.Size([1, 6, 4]))

In [28]:
bsz, q_len, _ = query.shape

query = query.view(bsz, q_len, n_heads_q, head_dim).transpose(1, 2)
key = key.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
value = value.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

In [29]:
query.shape, key.shape, value.shape

(torch.Size([1, 4, 6, 2]), torch.Size([1, 2, 6, 2]), torch.Size([1, 2, 6, 2]))

In [30]:
def get_Rope(query, key, head_dim, seq_len,  num_heads):

    cos, sin = get_sin_cos(dim = head_dim, seq_len = seq_len, max_seq_len = 2048, base = 10000)

    cos = cos.unsqueeze(0)
    sin = sin.unsqueeze(0)

    q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=1)

    return q_rotated, k_rotated
    

In [31]:
query_rotated, key_rotated = get_Rope(query, key, head_dim, seq_len,  num_heads)

In [32]:
query_rotated, key_rotated

(tensor([[[[-0.0113, -0.0217],
           [ 0.0164, -0.0711],
           [ 0.0026, -0.0137],
           [-0.0110, -0.0263],
           [ 0.0261, -0.0399],
           [ 0.0448,  0.0577]],
 
          [[-0.0369, -0.0231],
           [ 0.0025, -0.0622],
           [ 0.0677,  0.0699],
           [ 0.0649, -0.0412],
           [ 0.0105, -0.0802],
           [-0.0488, -0.0284]],
 
          [[ 0.0018,  0.0096],
           [-0.0681, -0.0234],
           [-0.0098, -0.0375],
           [-0.0772, -0.0972],
           [ 0.0373,  0.0051],
           [ 0.0964, -0.0133]],
 
          [[-0.0120,  0.0402],
           [ 0.0536,  0.0630],
           [ 0.0578,  0.0300],
           [-0.0002, -0.0895],
           [ 0.0079,  0.0557],
           [-0.0040, -0.0034]]]]),
 tensor([[[[-0.0682,  0.0924],
           [-0.0460, -0.0850],
           [ 0.1103,  0.0256],
           [-0.0496,  0.0828],
           [ 0.0406, -0.0557],
           [-0.1524, -0.1085]],
 
          [[-0.0166,  0.0247],
           [-0.0321, -0

In [33]:
value

tensor([[[[-0.0372, -0.0752],
          [-0.0334, -0.1026],
          [-0.0387,  0.0184],
          [-0.0112, -0.0671],
          [-0.0131,  0.0943],
          [ 0.1154,  0.0390]],

         [[-0.0424,  0.0104],
          [ 0.0050, -0.0126],
          [ 0.0497, -0.0733],
          [-0.0189, -0.0172],
          [-0.0999, -0.0100],
          [ 0.0345, -0.0383]]]])

In [34]:
def repeat_kv(x, n_rep):

    bsz, num_key_value_heads, seq_len, head_dim = x.shape

    if n_rep == 1:
        return x
    
    x = x[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
    return x.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)


In [35]:
key_rotated = repeat_kv(key_rotated, n_rep)
value = repeat_kv(value, n_rep)

In [36]:
query.shape, key_rotated.shape, value.shape

(torch.Size([1, 4, 6, 2]), torch.Size([1, 4, 6, 2]), torch.Size([1, 4, 6, 2]))

In [37]:
import math

def self_attention_rope(query, key, value, attn_mask = None, scale = None, is_causal=False):

    L, S = query.size(-2), key.size(-2)

    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)

    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask



    # (bsz, num_heads, tgt_len, head_dim) @ (bsz, num_heads, head_dim, tgt_len) -> (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias

    # (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = torch.softmax(attn_weight, dim=-1)


    sum_last_dim = attn_weight.sum(dim=-1)
    tolerance = 1e-6  
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"


    # # (bsz, num_heads, tgt_len, tgt_len) @ (bsz, num_heads, tgt_len, head_dim) -> (bsz, num_heads, tgt_len, head_dim) 
    attn_output = attn_weight @ value

    print("ATTEN OUTPUT = ", attn_output)

    return attn_output


In [38]:
seq_length = src_tokenized.shape[1]
sliding_window_len = 3

sliding_window_mask =  (1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

# print(sliding_window_mask)

for i in range(sliding_window_mask.shape[0]-1, -1, -1):

    li = i - sliding_window_len + 1

    if li > 0:

        sliding_window_mask[i][0:li] = False

sliding_window_mask

tensor([[ True, False, False, False, False, False],
        [ True,  True, False, False, False, False],
        [ True,  True,  True, False, False, False],
        [False,  True,  True,  True, False, False],
        [False, False,  True,  True,  True, False],
        [False, False, False,  True,  True,  True]])

In [39]:
self_attn_op = self_attention_rope(query_rotated, key_rotated, value, attn_mask = sliding_window_mask, is_causal=False)


ATTEN OUTPUT =  tensor([[[[-0.0372, -0.0752],
          [-0.0353, -0.0889],
          [-0.0364, -0.0531],
          [-0.0278, -0.0505],
          [-0.0210,  0.0153],
          [ 0.0301,  0.0219]],

         [[-0.0372, -0.0752],
          [-0.0353, -0.0889],
          [-0.0364, -0.0529],
          [-0.0278, -0.0503],
          [-0.0210,  0.0154],
          [ 0.0306,  0.0221]],

         [[-0.0424,  0.0104],
          [-0.0187, -0.0011],
          [ 0.0041, -0.0252],
          [ 0.0120, -0.0344],
          [-0.0230, -0.0335],
          [-0.0279, -0.0219]],

         [[-0.0424,  0.0104],
          [-0.0187, -0.0011],
          [ 0.0041, -0.0252],
          [ 0.0120, -0.0344],
          [-0.0230, -0.0335],
          [-0.0281, -0.0218]]]])


In [40]:
self_attn_op.shape

torch.Size([1, 4, 6, 2])

In [41]:
def out_proj_self_attn(self_attn_op, W, embed_dim):

    self_attn_op = self_attn_op.transpose(1, 2).contiguous()
    self_attn_op = self_attn_op.reshape(bsz, q_len, embed_dim)
    return self_attn_op@W.T

In [42]:
Wo = state_dict["layers.0.self_attn.o_proj.weight"]
print(Wo.shape)

sa_output = out_proj_self_attn(self_attn_op, Wo, embed_dim)

torch.Size([8, 8])


In [43]:
sa_output

# check and verified 

tensor([[[-5.4631e-04, -4.2851e-04, -2.7580e-03,  3.4945e-04,  1.8132e-03,
          -9.3882e-05, -1.9177e-03, -1.4553e-05],
         [ 4.6210e-04, -6.6375e-04, -3.0637e-03,  4.6621e-04,  3.0429e-03,
          -2.9930e-04,  3.0276e-04, -2.4526e-03],
         [ 1.6153e-03, -2.3367e-04, -1.4735e-03,  5.7002e-04,  2.9275e-03,
          -7.5637e-04,  2.8188e-03, -3.0160e-03],
         [ 2.0543e-03, -1.8132e-04, -1.3591e-03, -6.8881e-05,  3.3867e-03,
          -1.0794e-03,  3.4974e-03, -3.6740e-03],
         [ 8.1219e-04,  8.0704e-04,  7.4197e-04, -1.1923e-03,  6.8060e-04,
          -1.2882e-03,  3.6161e-04,  1.4475e-03],
         [-1.1690e-04,  7.3303e-04,  4.5419e-04, -3.7673e-03,  3.0093e-04,
          -1.3759e-03, -2.1586e-03,  2.1960e-03]]])

In [44]:
hidden_states = residual + sa_output

hidden_states


tensor([[[-0.0079, -0.0258,  0.0040,  0.0116,  0.0333,  0.0373,  0.0146,
          -0.0181],
         [-0.0126,  0.0211,  0.0195,  0.0529,  0.0132,  0.0360,  0.0036,
          -0.0161],
         [-0.0293,  0.0207, -0.0229,  0.0271, -0.0118, -0.0251,  0.0378,
           0.0220],
         [-0.0117, -0.0132,  0.0125,  0.0085, -0.0016, -0.0002, -0.0036,
           0.0047],
         [ 0.0384, -0.0118,  0.0074, -0.0021,  0.0192, -0.0385, -0.0125,
           0.0218],
         [-0.0162,  0.0156,  0.0074, -0.0136, -0.0115, -0.0137,  0.0155,
           0.0388]]])

In [45]:
residual = hidden_states

In [46]:
hidden_state = apply_layernorm(hidden_states, state_dict["layers.0.post_attention_layernorm.weight"], variance_epsilon = 1e-05)

hidden_state

tensor([[[-0.3540, -1.1556,  0.1783,  0.5175,  1.4878,  1.6687,  0.6509,
          -0.8112],
         [-0.4757,  0.7970,  0.7350,  1.9997,  0.4985,  1.3595,  0.1342,
          -0.6091],
         [-1.1389,  0.8025, -0.8887,  1.0517, -0.4577, -0.9731,  1.4679,
           0.8553],
         [-1.2954, -1.4519,  1.3833,  0.9339, -0.1723, -0.0233, -0.3952,
           0.5166],
         [ 1.6706, -0.5140,  0.3236, -0.0921,  0.8347, -1.6713, -0.5417,
           0.9473],
         [-0.8497,  0.8217,  0.3899, -0.7164, -0.6033, -0.7204,  0.8158,
           2.0416]]])

In [47]:
# Only when 'pretraining_tp == 1'
import torch.nn as nn


def llama_mlp(x, W_down_proj,W_up_proj, W_gate_proj):
    
    up_proj =  x@W_up_proj.T
    gate_proj =  x@W_gate_proj.T

    print("UP PROJ = ", up_proj)
    print()

    print("GATE PROJ = ", gate_proj)
    print()

    temp_proj = up_proj * gate_proj

    silu = nn.SiLU()

    temp_proj = silu(temp_proj)

    print("ACT = ",temp_proj)

    down_proj = temp_proj@W_down_proj.T

    return down_proj

In [48]:
W_down_proj = state_dict["layers.0.mlp.down_proj.weight"]

W_up_proj = state_dict["layers.0.mlp.up_proj.weight"]

W_gate_proj = state_dict["layers.0.mlp.gate_proj.weight"]

In [49]:
hidden_state = llama_mlp(hidden_state, W_down_proj,W_up_proj, W_gate_proj)
hidden_state

UP PROJ =  tensor([[[-0.0263, -0.0374,  0.0362,  0.0202,  0.0470,  0.0156, -0.0283,
           0.0319],
         [-0.1024,  0.0197, -0.0321, -0.0052,  0.0641,  0.0524, -0.0015,
           0.0010],
         [-0.0693, -0.0265, -0.0288,  0.0949,  0.0505,  0.0963, -0.0361,
           0.0130],
         [-0.0328,  0.0717, -0.0256, -0.0183,  0.0024,  0.0042, -0.0562,
          -0.0220],
         [ 0.0753,  0.0299,  0.0040, -0.0840, -0.0625, -0.0627, -0.1109,
          -0.0362],
         [-0.0098,  0.0205, -0.0038, -0.0214,  0.0036,  0.0421, -0.0381,
          -0.0661]]])

GATE PROJ =  tensor([[[-3.3223e-02,  1.1813e-01, -3.4386e-02, -1.7299e-02, -2.8870e-02,
          -3.1325e-02, -9.3374e-03, -3.7493e-02],
         [-3.8423e-02,  7.7386e-02, -7.7597e-02, -1.1311e-01, -3.3710e-03,
          -5.1375e-02,  1.0571e-04, -1.0072e-01],
         [ 7.4562e-02, -1.8549e-02,  7.8605e-03,  4.1163e-02, -1.6428e-02,
          -3.1222e-02, -3.1268e-02,  1.4060e-02],
         [ 9.3609e-03, -3.4470e-02, -3.7

tensor([[[-1.1604e-04,  1.0476e-04,  5.0113e-05,  8.0561e-06,  3.7039e-05,
           6.7447e-05,  4.9149e-07, -2.5759e-06],
         [ 4.9020e-06,  1.3651e-05,  2.6720e-05, -9.1400e-05,  2.5665e-05,
          -4.4363e-05,  9.8443e-06, -1.4415e-05],
         [-5.1641e-05, -1.4457e-05, -7.2304e-05,  1.1696e-05, -6.9724e-05,
           5.1070e-05, -5.3767e-05, -4.8957e-05],
         [-1.6031e-05,  1.7560e-05,  9.4520e-06, -4.8031e-05, -1.9137e-05,
          -2.3326e-05, -2.9325e-05, -2.1525e-05],
         [ 1.5106e-05,  1.4085e-05,  1.0385e-04, -2.3431e-05, -7.3604e-06,
           6.7487e-05,  7.6150e-05, -1.7018e-06],
         [-1.2456e-05,  2.0832e-05,  8.2834e-06, -4.5761e-06, -2.8004e-05,
          -5.5918e-05,  2.3039e-05, -1.7661e-05]]])

In [50]:
hidden_state = hidden_state + residual

In [51]:
hidden_state

tensor([[[-0.0080, -0.0257,  0.0040,  0.0116,  0.0333,  0.0374,  0.0146,
          -0.0181],
         [-0.0126,  0.0211,  0.0195,  0.0529,  0.0132,  0.0360,  0.0036,
          -0.0161],
         [-0.0294,  0.0207, -0.0230,  0.0271, -0.0119, -0.0250,  0.0377,
           0.0220],
         [-0.0118, -0.0131,  0.0125,  0.0084, -0.0016, -0.0002, -0.0036,
           0.0047],
         [ 0.0385, -0.0118,  0.0076, -0.0021,  0.0192, -0.0384, -0.0124,
           0.0218],
         [-0.0162,  0.0156,  0.0074, -0.0136, -0.0115, -0.0138,  0.0155,
           0.0388]]])