In [1]:
from transformers.trainer_utils import set_seed
import torch
SEED = 6
set_seed(SEED)

In [2]:
from transformers import AutoConfig, MistralConfig


mistral_config = MistralConfig.from_pretrained("openaccess-ai-collective/tiny-mistral", num_hidden_layers = 1, use_cache = False, hidden_size = 8, num_attention_heads = 4, 
                                           output_hidden_states=True,  num_key_value_heads = 2, past_key_values = True, intermediate_size = 8, sliding_window = 3, dropout_p = 0)


mistral_config

MistralConfig {
  "_name_or_path": "./tiny-mistral",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "dropout_p": 0,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 8,
  "initializer_range": 0.02,
  "intermediate_size": 8,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 4,
  "num_hidden_layers": 1,
  "num_key_value_heads": 2,
  "output_hidden_states": true,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 3,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": false,
  "vocab_size": 32000
}

In [3]:
from transformers import AutoModel

tinymistral = AutoModel.from_config(mistral_config)

In [4]:
from transformers import AutoTokenizer

src_sent = "hi how are you doing"

mistal_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

In [5]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [6]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [7]:
src_tokenized = tokenized_src_dict["input_ids"]
src_tokenized

tensor([[    1, 12014,   910,   460,   368,  2548]])

In [8]:
mistal_tokenizer.decode(*src_tokenized)

'<s> hi how are you doing'

In [29]:
src_tokenized_np = src_tokenized.numpy()[0]
src_tokenized_np.shape

(6,)

# Mistral 

## **SLIDING WINDOW ATTENTION** and **GROUPED-QUERY ATTENTION**

In [30]:
mistral_config.num_key_value_heads

2

### Defining hyperparameters

In [31]:
embed_dim = d_model = 8

num_heads = 4

n_heads_q = num_heads

n_kv_heads = 2

seq_len = 6

head_dim = embed_dim // num_heads

n_rep = n_heads_q//n_kv_heads
print("n_rep = ", n_rep)

n_rep =  2


In [32]:
state_dict = tinymistral.state_dict()

### 1. Source token embeddings 

In [34]:
import numpy as np


src_vocab_embeds = state_dict["embed_tokens.weight"].numpy()

src_embedding = np.zeros((src_tokenized_np.shape[0], d_model))

for i in range(src_tokenized_np.shape[0]):
        word_index = src_tokenized_np[i]
        if word_index < 0 or word_index >= src_vocab_embeds.shape[0]:
            raise ValueError(f"Invalid word index: {word_index}")
        src_embedding[i, :] = src_vocab_embeds[word_index, :]

        print(f"Word index: {word_index}, Embedding: {src_vocab_embeds[word_index, :]}")
print()
print(src_embedding.shape)


Word index: 1, Embedding: [-0.00737046 -0.02541412  0.00674542  0.01122305  0.0314588   0.0374115
  0.01647354 -0.0181274 ]
Word index: 12014, Embedding: [-0.01305792  0.02176759  0.02252432  0.05248258  0.0101565   0.03629604
  0.00324979 -0.01367457]
Word index: 910, Embedding: [-0.03094564  0.02090054 -0.02141277  0.02651504 -0.01471478 -0.02430361
  0.03498438  0.02504197]
Word index: 460, Embedding: [-0.01378937 -0.01297163  0.01389042  0.00852889 -0.00494728  0.00086873
 -0.00707719  0.00835401]
Word index: 368, Embedding: [ 0.03763684 -0.01263603  0.00670438 -0.00092739  0.01852957 -0.03717487
 -0.01282768  0.02035378]
Word index: 2548, Embedding: [-0.01604516  0.01489605  0.00696284 -0.00985821 -0.0117752  -0.01232624
  0.01767547  0.03663577]

(6, 8)


In [35]:
residual = src_embedding

### 2. Pre-normalization (RMSNorm)

In [36]:
variance_epsilon = 1e-05

wt = state_dict["layers.0.input_layernorm.weight"].numpy()

dtype = src_embedding.dtype
src_embedding = src_embedding.astype(np.float32)
variance = np.mean(src_embedding**2, axis=-1, keepdims=True)
src_embedding = src_embedding * (1/np.sqrt(variance + variance_epsilon))

hidden_state = wt * src_embedding

In [37]:
hidden_state, hidden_state.shape

(array([[-0.33207855, -1.1450424 ,  0.30391717,  0.5056585 ,  1.4173874 ,
          1.6855887 ,  0.7422213 , -0.8167365 ],
        [-0.49335742,  0.8224283 ,  0.85101897,  1.982909  ,  0.3837353 ,
          1.3713453 ,  0.12278415, -0.51665574],
        [-1.2029388 ,  0.81245935, -0.83237123,  1.0307097 , -0.57200253,
         -0.94474554,  1.3599354 ,  0.9734475 ],
        [-1.3383939 , -1.2590235 ,  1.3482019 ,  0.82781243, -0.48018178,
          0.08431846, -0.68691105,  0.810839  ],
        [ 1.6818779 , -0.5646665 ,  0.2995986 , -0.0414422 ,  0.8280312 ,
         -1.661234  , -0.5732305 ,  0.9095496 ],
        [-0.881556  ,  0.8184212 ,  0.3825538 , -0.5416313 , -0.6469551 ,
         -0.6772302 ,  0.9711286 ,  2.0128489 ]], dtype=float32),
 (6, 8))

### 3. Getting the Q,K and V matrices for attention calcualtion

In [38]:
Wq = state_dict["layers.0.self_attn.q_proj.weight"].numpy()
Wk = state_dict["layers.0.self_attn.k_proj.weight"].numpy()
Wv = state_dict["layers.0.self_attn.v_proj.weight"].numpy()


query = np.matmul(hidden_state, Wq.T)
key = np.matmul(hidden_state, Wk.T)
value = np.matmul(hidden_state, Wv.T)

In [39]:
query, key, value

(array([[-0.01127685, -0.02173477, -0.03691161, -0.02306341,  0.00182827,
          0.00958379, -0.01198711,  0.04022793],
        [-0.05099525, -0.05224977, -0.05096808, -0.03565389, -0.05645479,
          0.04469625,  0.0819565 , -0.01108029],
        [-0.01349053,  0.00335787,  0.03539798, -0.09065476, -0.03000753,
          0.02455559,  0.0032218 , -0.06504469],
        [ 0.00718971,  0.02755949, -0.07008822,  0.03163479,  0.06265695,
          0.10716037, -0.01247996,  0.08862031],
        [ 0.01310077,  0.04581054,  0.05383541,  0.06036859, -0.02822899,
          0.02486397, -0.04735787, -0.03042608],
        [-0.04263002,  0.05932804,  0.01338688, -0.05481394,  0.04012501,
          0.08863796,  0.00213639, -0.00484931]], dtype=float32),
 array([[-0.068159  ,  0.09239922, -0.01655504,  0.0246703 ],
        [-0.09634116, -0.00719259, -0.02887959,  0.01964808],
        [-0.02263346, -0.11099781, -0.04456816,  0.00532937],
        [ 0.06082123, -0.0749862 , -0.01108772, -0.011057  

In [41]:
print("query_shape = ", query.shape)
print("key_shape = ", key.shape)
print("value_shape = ", value.shape)
print()

query_shape =  (6, 8)
key_shape =  (6, 4)
value_shape =  (6, 4)



In [40]:
q_len, _ = query.shape


print("After reshaping... \n")

query1 = np.transpose(np.reshape(query, (q_len, num_heads, head_dim)), (1, 0, 2))
key1 = np.transpose(np.reshape(key, (q_len, n_kv_heads, head_dim)), (1, 0, 2))
value1 = np.transpose(np.reshape(value, (q_len, n_kv_heads, head_dim)), (1, 0, 2))

print("query_shape = ", query1.shape)
print("key_shape = ", key1.shape)
print("value_shape = ", value1.shape)
print()

After reshaping... 

query_shape =  (4, 6, 2)
key_shape =  (2, 6, 2)
value_shape =  (2, 6, 2)



### 4. Obtaining the rotary embeddings 

#### 4.1 Pre-computing the sin and cos values

In [42]:
base = 10000
max_seq_len = 2048
dim = head_dim


inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
t = np.arange(max_seq_len, dtype=np.float32)
freqs = np.outer(t, inv_freq)
# emb = np.concatenate((np.cos(freqs), np.sin(freqs)), axis=-1)
emb = np.concatenate((freqs,freqs), axis=-1)

cos, sin =  np.cos(emb[:seq_len]), np.sin(emb[:seq_len])

In [43]:
cos, sin

(array([[ 1.        ,  1.        ],
        [ 0.5403023 ,  0.5403023 ],
        [-0.4161468 , -0.4161468 ],
        [-0.9899925 , -0.9899925 ],
        [-0.6536436 , -0.6536436 ],
        [ 0.28366217,  0.28366217]], dtype=float32),
 array([[ 0.       ,  0.       ],
        [ 0.841471 ,  0.841471 ],
        [ 0.9092974,  0.9092974],
        [ 0.14112  ,  0.14112  ],
        [-0.7568025, -0.7568025],
        [-0.9589243, -0.9589243]], dtype=float32))

#### 4.2 Applying the rotations on the Q and K matrices 

In [44]:
# Q matrix rotation

unsqueeze_dim = 0

cos_exp = np.expand_dims(cos, axis=unsqueeze_dim)
sin_exp = np.expand_dims(sin, axis=unsqueeze_dim)


# Half rotation 
q1 = query1[..., :query1.shape[-1] // 2]
q2 = query1[..., query1.shape[-1] // 2:]
q_half_rot = np.concatenate((-q2, q1), axis=-1)


query_rotated = query1*cos_exp + q_half_rot*sin_exp


In [45]:
# K matrix rotation

unskueeze_dim = 0

cos_exp = np.expand_dims(cos, axis=unskueeze_dim)
sin_exp = np.expand_dims(sin, axis=unskueeze_dim)


# Half rotation 
k1 = key1[..., :key1.shape[-1] // 2]
k2 = key1[..., key1.shape[-1] // 2:]
key_half_rot = np.concatenate((-k2, k1), axis=-1)


key_rotated = key1*cos_exp + key_half_rot*sin_exp


In [46]:
query_rotated, key_rotated

(array([[[-0.01127685, -0.02173477],
         [ 0.01641382, -0.07114169],
         [ 0.00256074, -0.01366427],
         [-0.01100695, -0.02626907],
         [ 0.0261063 , -0.03985846],
         [ 0.04479858,  0.05770808]],
 
        [[-0.03691161, -0.02306341],
         [ 0.00246354, -0.06215204],
         [ 0.06770138,  0.06991298],
         [ 0.06492251, -0.04120906],
         [ 0.01049793, -0.08020231],
         [-0.04876507, -0.02838565]],
 
        [[ 0.00182827,  0.00958379],
         [-0.06811325, -0.02335558],
         [-0.00984079, -0.0375045 ],
         [-0.07715238, -0.09724581],
         [ 0.03726882,  0.00511159],
         [ 0.09637903, -0.01333361]],
 
        [[-0.01198711,  0.04022793],
         [ 0.05360502,  0.06297731],
         [ 0.05780423,  0.02999771],
         [-0.00015103, -0.08949462],
         [ 0.00792864,  0.05572837],
         [-0.00404411, -0.0034242 ]]], dtype=float32),
 array([[[-0.068159  ,  0.09239922],
         [-0.04600099, -0.08495446],
         [ 

In [47]:
value1

array([[[-0.03718761, -0.07516474],
        [-0.03341069, -0.10257128],
        [-0.03865023,  0.01835186],
        [-0.01121775, -0.06711451],
        [-0.01313276,  0.09430597],
        [ 0.11535889,  0.03899225]],

       [[-0.04240099,  0.01042155],
        [ 0.00500367, -0.01255119],
        [ 0.04970702, -0.07333571],
        [-0.01892604, -0.01721327],
        [-0.09991245, -0.01003497],
        [ 0.03449281, -0.03830849]]], dtype=float32)

## 5. Sliding window attention
### (Grouped Query Attention)

#### 5.1 Repating the K and V values for the GQA

In [48]:
# Repeating the Value vector 

num_value_value_heads, seq_len, head_dim = value1.shape[0], value1.shape[1], value1.shape[2]

if n_rep > 1:

    value1 = np.broadcast_to(value1[:, np.newaxis, :, :], (num_value_value_heads, n_rep, seq_len, head_dim))
    value1 =  value1.reshape(num_value_value_heads * n_rep, seq_len, head_dim)

In [49]:
value1

array([[[-0.03718761, -0.07516474],
        [-0.03341069, -0.10257128],
        [-0.03865023,  0.01835186],
        [-0.01121775, -0.06711451],
        [-0.01313276,  0.09430597],
        [ 0.11535889,  0.03899225]],

       [[-0.03718761, -0.07516474],
        [-0.03341069, -0.10257128],
        [-0.03865023,  0.01835186],
        [-0.01121775, -0.06711451],
        [-0.01313276,  0.09430597],
        [ 0.11535889,  0.03899225]],

       [[-0.04240099,  0.01042155],
        [ 0.00500367, -0.01255119],
        [ 0.04970702, -0.07333571],
        [-0.01892604, -0.01721327],
        [-0.09991245, -0.01003497],
        [ 0.03449281, -0.03830849]],

       [[-0.04240099,  0.01042155],
        [ 0.00500367, -0.01255119],
        [ 0.04970702, -0.07333571],
        [-0.01892604, -0.01721327],
        [-0.09991245, -0.01003497],
        [ 0.03449281, -0.03830849]]], dtype=float32)

In [50]:
# Repeating the Key vector 

num_key_value_heads, seq_len, head_dim = key_rotated.shape[0], key_rotated.shape[1], key_rotated.shape[2]

if n_rep > 1:

    key_rotated = np.broadcast_to(key_rotated[:, np.newaxis, :, :], (num_key_value_heads, n_rep, seq_len, head_dim))
    key_rotated =  key_rotated.reshape(num_key_value_heads * n_rep, seq_len, head_dim)

In [51]:
key_rotated

array([[[-0.068159  ,  0.09239922],
        [-0.04600099, -0.08495446],
        [ 0.11034887,  0.02561084],
        [-0.04963051,  0.08281887],
        [ 0.04055563, -0.0557485 ],
        [-0.15242632, -0.10854888]],

       [[-0.068159  ,  0.09239922],
        [-0.04600099, -0.08495446],
        [ 0.11034887,  0.02561084],
        [-0.04963051,  0.08281887],
        [ 0.04055563, -0.0557485 ],
        [-0.15242632, -0.10854888]],

       [[-0.01655504,  0.0246703 ],
        [-0.032137  , -0.01368544],
        [ 0.01370091, -0.04274352],
        [ 0.01253712,  0.00938164],
        [-0.09656778, -0.02180969],
        [ 0.04380195,  0.0521143 ]],

       [[-0.01655504,  0.0246703 ],
        [-0.032137  , -0.01368544],
        [ 0.01370091, -0.04274352],
        [ 0.01253712,  0.00938164],
        [-0.09656778, -0.02180969],
        [ 0.04380195,  0.0521143 ]]], dtype=float32)

In [52]:
query_rotated.shape, key_rotated.shape, value1.shape

((4, 6, 2), (4, 6, 2), (4, 6, 2))

### Sliding window attention mask 

In [55]:
seq_length = src_tokenized_np.shape[0]
sliding_window_len = 3

sliding_window_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')

# print(sliding_window_mask)

for i in range(sliding_window_mask.shape[0]-1, -1, -1):

    li = i - sliding_window_len + 1

    if li > 0:

        sliding_window_mask[i][0:li] = True

sliding_window_mask

array([[False,  True,  True,  True,  True,  True],
       [False, False,  True,  True,  True,  True],
       [False, False, False,  True,  True,  True],
       [ True, False, False, False,  True,  True],
       [ True,  True, False, False, False,  True],
       [ True,  True,  True, False, False, False]])

In [56]:
attn_mask = sliding_window_mask

In [58]:
L, S = query.shape[-2], key.shape[-2]

scale_factor = 1 / np.sqrt(query.shape[-1])
attn_bias = np.zeros((L, S), dtype=query.dtype)

if attn_mask is not None:
    if attn_mask.dtype == np.bool_:
        attn_bias[attn_mask] = -np.inf
    else:
        attn_bias += attn_mask

    print("Attention bias = ")
    print(attn_bias)

Attention bias = 
[[  0. -inf -inf -inf -inf -inf]
 [  0.   0. -inf -inf -inf -inf]
 [  0.   0.   0. -inf -inf -inf]
 [-inf   0.   0.   0. -inf -inf]
 [-inf -inf   0.   0.   0. -inf]
 [-inf -inf -inf   0.   0.   0.]]


In [59]:
key_rotated_T = np.transpose(key_rotated, axes=(0, 2, 1))

attn_weight = query_rotated @ key_rotated_T * scale_factor
attn_weight += attn_bias

exp_attn_weight = np.exp(attn_weight)
sum_exp_attn_weight = np.sum(exp_attn_weight, axis=-1, keepdims=True)
softmax_attn_weight = exp_attn_weight / sum_exp_attn_weight

print("SoftMax (Scaled Dot Product of Q and K) = ")
softmax_attn_weight

SoftMax (Scaled Dot Product of Q and K) = 


array([[[1.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.49885264, 0.50114733, 0.        , 0.        , 0.        ,
         0.        ],
        [0.33318213, 0.3334744 , 0.3333435 , 0.        , 0.        ,
         0.        ],
        [0.        , 0.3336867 , 0.33314148, 0.33317187, 0.        ,
         0.        ],
        [0.        , 0.        , 0.33353096, 0.33277047, 0.33369857,
         0.        ],
        [0.        , 0.        , 0.        , 0.33410344, 0.33363643,
         0.3322601 ]],

       [[1.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.49902087, 0.5009791 , 0.        , 0.        , 0.        ,
         0.        ],
        [0.33346903, 0.3321865 , 0.33434448, 0.        , 0.        ,
         0.        ],
        [0.        , 0.33339387, 0.33405393, 0.3325522 , 0.        ,
         0.        ],
        [0.        , 0.        , 0.3333515 , 0.33261362, 0.3340349 ,
         0.        

In [60]:
attn_output = softmax_attn_weight @ value1

In [61]:
attn_output

array([[[-0.03718761, -0.07516474],
        [-0.03529482, -0.08889945],
        [-0.03641566, -0.05313097],
        [-0.02776214, -0.05047357],
        [-0.02100636,  0.01525696],
        [ 0.0301997 ,  0.02199629]],

       [[-0.03718761, -0.07516474],
        [-0.03529545, -0.08889484],
        [-0.03642198, -0.05300206],
        [-0.02778067, -0.0503852 ],
        [-0.02100209,  0.01529591],
        [ 0.03044323,  0.02207814]],

       [[-0.04240099,  0.01042155],
        [-0.01869046, -0.00106879],
        [ 0.00411539, -0.02516541],
        [ 0.01194573, -0.03437496],
        [-0.02300747, -0.03353863],
        [-0.0280118 , -0.02187104]],

       [[-0.04240099,  0.01042155],
        [-0.01871228, -0.00105822],
        [ 0.00410155, -0.02515661],
        [ 0.0119473 , -0.03438282],
        [-0.02304507, -0.0335216 ],
        [-0.02812188, -0.02185095]]], dtype=float32)

### 6. Post attention 

In [62]:
Wo = state_dict["layers.0.self_attn.o_proj.weight"].numpy()

self_attn_op = np.transpose(attn_output, (1, 0, 2)).copy()
self_attn_op = self_attn_op.reshape(q_len, embed_dim)

sa_output = np.matmul(self_attn_op, Wo.T)

In [63]:
sa_output, sa_output.shape

(array([[-5.4631353e-04, -4.2850545e-04, -2.7580101e-03,  3.4945257e-04,
          1.8132386e-03, -9.3881667e-05, -1.9176917e-03, -1.4552541e-05],
        [ 4.6211353e-04, -6.6357432e-04, -3.0626440e-03,  4.6732419e-04,
          3.0408024e-03, -2.9950534e-04,  3.0283217e-04, -2.4522718e-03],
        [ 1.6187821e-03, -2.3624374e-04, -1.4780506e-03,  5.6520052e-04,
          2.9295764e-03, -7.5796153e-04,  2.8198441e-03, -3.0192956e-03],
        [ 2.0560285e-03, -1.8307311e-04, -1.3614881e-03, -7.2920178e-05,
          3.3871541e-03, -1.0801419e-03,  3.4962103e-03, -3.6736529e-03],
        [ 8.1324280e-04,  8.0571848e-04,  7.3775818e-04, -1.1942638e-03,
          6.8257906e-04, -1.2898125e-03,  3.6089559e-04,  1.4434158e-03],
        [-1.1867824e-04,  7.3713245e-04,  4.5464179e-04, -3.7712425e-03,
          2.9719065e-04, -1.3826506e-03, -2.1593508e-03,  2.1981997e-03]],
       dtype=float32),
 (6, 8))

In [64]:
hidden_states = residual + sa_output

hidden_states


array([[-0.00791677, -0.02584263,  0.00398741,  0.0115725 ,  0.03327203,
         0.03731762,  0.01455585, -0.01814195],
       [-0.01259581,  0.02110402,  0.01946167,  0.0529499 ,  0.0131973 ,
         0.03599653,  0.00355262, -0.01612684],
       [-0.02932685,  0.0206643 , -0.02289083,  0.02708024, -0.01178521,
        -0.02506157,  0.03780422,  0.02202267],
       [-0.01173335, -0.0131547 ,  0.01252894,  0.00845597, -0.00156012,
        -0.00021141, -0.00358098,  0.00468036],
       [ 0.03845008, -0.01183031,  0.00744214, -0.00212165,  0.01921215,
        -0.03846469, -0.01246678,  0.02179719],
       [-0.01616384,  0.01563318,  0.00741749, -0.01362945, -0.01147801,
        -0.01370889,  0.01551612,  0.03883397]])

In [65]:
residual = hidden_states

#### 7. LayerNorm 2 

In [66]:
variance_epsilon = 1e-05

wt = state_dict["layers.0.post_attention_layernorm.weight"].numpy()

dtype = hidden_states.dtype
hidden_states = hidden_states.astype(np.float32)
variance = np.mean(hidden_states**2, axis=-1, keepdims=True)
hidden_states = hidden_states * (1/np.sqrt(variance + variance_epsilon))

hidden_state = wt * hidden_states

In [67]:
hidden_state

array([[-0.35400555, -1.1555766 ,  0.17830051,  0.51747495,  1.4877894 ,
         1.6686914 ,  0.650878  , -0.8112339 ],
       [-0.47569054,  0.79700977,  0.73498535,  1.9996945 ,  0.49840653,
         1.3594372 ,  0.13416739, -0.60904276],
       [-1.1387956 ,  0.8024186 , -0.88887715,  1.0515571 , -0.45763317,
        -0.9731697 ,  1.4679816 ,  0.8551657 ],
       [-1.2953333 , -1.4522474 ,  1.3831645 ,  0.933519  , -0.17223391,
        -0.0233397 , -0.39533207,  0.51670074],
       [ 1.6706715 , -0.5140318 ,  0.3233638 , -0.09218662,  0.8347757 ,
        -1.671306  , -0.5416866 ,  0.9470969 ],
       [-0.84970623,  0.8218103 ,  0.38992497, -0.7164776 , -0.60337996,
        -0.7206535 ,  0.81565654,  2.0414376 ]], dtype=float32)

### 8. MLP layer of mistral

In [68]:
W_down_proj = state_dict["layers.0.mlp.down_proj.weight"].numpy()

W_up_proj = state_dict["layers.0.mlp.up_proj.weight"].numpy()

W_gate_proj = state_dict["layers.0.mlp.gate_proj.weight"].numpy()

In [69]:
up_proj = np.matmul(hidden_state, W_up_proj.T)
gate_proj = np.matmul(hidden_state, W_gate_proj.T)

print("UP PROJ = ", up_proj)
print()

print("GATE PROJ = ", gate_proj)
print()

temp_proj = up_proj * gate_proj


# SilU 
temp_proj = temp_proj / (1 + np.exp(-temp_proj))

print("ACT = ", temp_proj)

down_proj = np.matmul(temp_proj, W_down_proj.T)


UP PROJ =  [[-0.02633362 -0.03738315  0.03618297  0.02019078  0.04695692  0.01560353
  -0.02826723  0.03186499]
 [-0.10240708  0.01969982 -0.03214018 -0.00515962  0.0641224   0.05243002
  -0.00151322  0.00098214]
 [-0.06928502 -0.02650194 -0.0288421   0.09493146  0.05052106  0.09632429
  -0.03608317  0.01299343]
 [-0.03276506  0.07164638 -0.02559902 -0.01827805  0.00236751  0.00413582
  -0.0561991  -0.02200339]
 [ 0.0753521   0.02986558  0.00396085 -0.083946   -0.06246215 -0.06265807
  -0.11088322 -0.03623838]
 [-0.00975075  0.02046935 -0.00375318 -0.02138426  0.00363632  0.04210124
  -0.03814134 -0.06613762]]

GATE PROJ =  [[-3.32227238e-02  1.18128814e-01 -3.43855433e-02 -1.72991175e-02
  -2.88696140e-02 -3.13251726e-02 -9.33740381e-03 -3.74927334e-02]
 [-3.84216122e-02  7.73807988e-02 -7.75982216e-02 -1.13112681e-01
  -3.36851692e-03 -5.13751172e-02  1.05518986e-04 -1.00727126e-01]
 [ 7.45595098e-02 -1.85438283e-02  7.86799844e-03  4.11682576e-02
  -1.64301712e-02 -3.12208980e-02 -3

In [70]:
hidden_state = down_proj
hidden_state

array([[-1.1603549e-04,  1.0476179e-04,  5.0112540e-05,  8.0561222e-06,
         3.7039474e-05,  6.7446716e-05,  4.9148895e-07, -2.5759321e-06],
       [ 4.9017112e-06,  1.3649859e-05,  2.6713100e-05, -9.1412177e-05,
         2.5664525e-05, -4.4378277e-05,  9.8380751e-06, -1.4420082e-05],
       [-5.1639217e-05, -1.4465520e-05, -7.2297342e-05,  1.1697599e-05,
        -6.9711867e-05,  5.1093048e-05, -5.3775126e-05, -4.8950795e-05],
       [-1.6044300e-05,  1.7578410e-05,  9.4723309e-06, -4.8006841e-05,
        -1.9136964e-05, -2.3289011e-05, -2.9307079e-05, -2.1518310e-05],
       [ 1.5083115e-05,  1.4094624e-05,  1.0383761e-04, -2.3425488e-05,
        -7.3550050e-06,  6.7495741e-05,  7.6142584e-05, -1.7058817e-06],
       [-1.2485162e-05,  2.0857897e-05,  8.2886154e-06, -4.5903553e-06,
        -2.7969083e-05, -5.5890367e-05,  2.3006802e-05, -1.7650222e-05]],
      dtype=float32)

In [71]:
# Residual connection
hidden_state = hidden_state + residual

In [72]:
hidden_state

array([[-0.0080328 , -0.02573786,  0.00403752,  0.01158056,  0.03330907,
         0.03738506,  0.01455634, -0.01814453],
       [-0.01259091,  0.02111767,  0.01948838,  0.05285849,  0.01322297,
         0.03595215,  0.00356246, -0.01614126],
       [-0.02937849,  0.02064983, -0.02296312,  0.02709194, -0.01185492,
        -0.02501048,  0.03775045,  0.02197372],
       [-0.01174939, -0.01313712,  0.01253841,  0.00840796, -0.00157926,
        -0.0002347 , -0.00361029,  0.00465884],
       [ 0.03846516, -0.01181622,  0.00754597, -0.00214508,  0.0192048 ,
        -0.03839719, -0.01239064,  0.02179549],
       [-0.01617632,  0.01565404,  0.00742577, -0.01363404, -0.01150598,
        -0.01376478,  0.01553912,  0.03881632]])

In [73]:
hidden_state.shape

(6, 8)