## Mixtral 8X7B 

Tiny Mixtral is used for checking the computations (Small model, same architercture)

In [1]:
from transformers.trainer_utils import set_seed
import torch
SEED = 6
set_seed(SEED)

In [2]:
from transformers import AutoConfig, MixtralConfig


mixtral_config = MixtralConfig.from_pretrained("NickyNicky/Mixtral-TinyMistral-8x248M-Instruct_oasst2_chatML_Intel_orca_dpo_pairs_DPO_V1", num_hidden_layers = 1, use_cache = False, hidden_size = 8, num_attention_heads = 4, 
                                           output_hidden_states=True,  num_key_value_heads = 2, past_key_values = True, intermediate_size = 8, sliding_window = 3, dropout_p = 0, 
                                           
                                           num_local_experts = 4, num_experts_per_tok = 2)

mixtral_config

MixtralConfig {
  "_name_or_path": "NickyNicky/LocutusqueXFelladrin-TinyMistral248M-Instruct_oasst2_chatML_V4",
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 8,
  "initializer_range": 0.02,
  "intermediate_size": 8,
  "max_position_embeddings": 32768,
  "model_type": "mixtral",
  "num_attention_heads": 4,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 1,
  "num_key_value_heads": 2,
  "num_local_experts": 4,
  "output_hidden_states": true,
  "output_router_logits": false,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "router_aux_loss_coef": 0.001,
  "sliding_window": 3,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": false,
  "vocab_size": 32005
}

In [3]:
from transformers import AutoModel

tinymixtral = AutoModel.from_config(mixtral_config)

In [4]:
from transformers import AutoTokenizer

src_sent = "hi how are you doing"

mistal_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

In [5]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [6]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [7]:
src_tokenized = tokenized_src_dict["input_ids"]
src_tokenized

tensor([[    1, 12014,   910,   460,   368,  2548]])

In [8]:
mistal_tokenizer.decode(*src_tokenized)

'<s> hi how are you doing'

In [9]:
src_tokenized_np = src_tokenized.numpy()[0]
src_tokenized_np.shape

(6,)

### **SLIDING WINDOW ATTENTION** and **GROUPED-QUERY ATTENTION**

### Defining hyperparameters

In [10]:
# Hyperparameters of the model we loaded

# Hidden size
embed_dim = d_model = 8

num_heads = 4

# For grouped query attention 
# No.of query heds 
n_heads_q = num_heads

# No.of key and values heads 
n_kv_heads = 2

# Sequence length of the input passed 
seq_len = 6

# Dimension per head
head_dim = embed_dim // num_heads

# No.of time the K and V matrices needs to be repeated
n_rep = n_heads_q//n_kv_heads
print("n_rep = ", n_rep)


# MOE hyperparameters 
num_local_experts = 4
num_experts_per_tok = 2

n_rep =  2


In [11]:
state_dict = tinymixtral.state_dict()

### 1. Source token embeddings 

In [12]:
import numpy as np


src_vocab_embeds = state_dict["embed_tokens.weight"].numpy()

src_embedding = np.zeros((src_tokenized_np.shape[0], d_model))

for i in range(src_tokenized_np.shape[0]):
        word_index = src_tokenized_np[i]
        if word_index < 0 or word_index >= src_vocab_embeds.shape[0]:
            raise ValueError(f"Invalid word index: {word_index}")
        src_embedding[i, :] = src_vocab_embeds[word_index, :]

        print(f"Word index: {word_index}, Embedding: {src_vocab_embeds[word_index, :]}")
print()
print(src_embedding.shape)


Word index: 1, Embedding: [-0.02049089  0.00325612 -0.02274372  0.00900858 -0.01572997 -0.01957951
  0.02586581  0.00320597]
Word index: 12014, Embedding: [ 0.01558615  0.00148417  0.02867359 -0.00886034  0.0088985   0.0154724
 -0.00872641 -0.01009076]
Word index: 910, Embedding: [-0.0133071  -0.01638352  0.01268725 -0.01078121  0.0175229  -0.00425425
 -0.00972286  0.01993921]
Word index: 460, Embedding: [ 0.00419843 -0.0033648   0.00210035 -0.01170026 -0.00829275  0.01203581
  0.03342133  0.01750135]
Word index: 368, Embedding: [-0.00782214  0.02998792 -0.01133393  0.00735941  0.01696292  0.01203462
 -0.0012349  -0.03832581]
Word index: 2548, Embedding: [-0.01400485  0.029383   -0.0047281   0.00600842  0.00859148 -0.02857112
 -0.01949297  0.04272375]

(6, 8)


In [13]:
residual = src_embedding

### 2. Pre-normalization (RMSNorm)

In [14]:
variance_epsilon = 1e-05

wt = state_dict["layers.0.input_layernorm.weight"].numpy()

dtype = src_embedding.dtype
src_embedding = src_embedding.astype(np.float32)
variance = np.mean(src_embedding**2, axis=-1, keepdims=True)
src_embedding = src_embedding * (1/np.sqrt(variance + variance_epsilon))

hidden_state = wt * src_embedding

hidden_state, hidden_state.shape

(array([[-1.1783826 ,  0.1872517 , -1.3079375 ,  0.5180622 , -0.90459317,
         -1.1259714 ,  1.4874815 ,  0.18436776],
        [ 1.062237  ,  0.10115016,  1.9541806 , -0.60385543,  0.6064558 ,
          1.0544846 , -0.59472793, -0.687712  ],
        [-0.9351401 , -1.1513319 ,  0.8915809 , -0.7576364 ,  1.2314006 ,
         -0.29896203, -0.6832625 ,  1.4012038 ],
        [ 0.27335456, -0.21907791,  0.13675119, -0.7617898 , -0.53993076,
          0.7836371 ,  2.1760218 ,  1.1394916 ],
        [-0.39570653,  1.5170282 , -0.57336086,  0.37229788,  0.85811996,
          0.60880727, -0.06247089, -1.9388255 ],
        [-0.60453886,  1.2683583 , -0.20409523,  0.25936186,  0.37086317,
         -1.2333126 , -0.84144133,  1.8442307 ]], dtype=float32),
 (6, 8))

### 3. Getting the Q,K and V matrices for attention calcualtion

In [15]:
Wq = state_dict["layers.0.self_attn.q_proj.weight"].numpy()
Wk = state_dict["layers.0.self_attn.k_proj.weight"].numpy()
Wv = state_dict["layers.0.self_attn.v_proj.weight"].numpy()


query = np.matmul(hidden_state, Wq.T)
key = np.matmul(hidden_state, Wk.T)
value = np.matmul(hidden_state, Wv.T)

In [16]:
query, key, value

(array([[-2.39523221e-02,  2.95548211e-03,  9.91353579e-03,
         -1.05658919e-01, -3.85008045e-02,  1.30411331e-03,
         -5.90171143e-02,  8.17328617e-02],
        [-1.24330502e-02,  5.98522369e-03, -1.37904333e-02,
          9.24208090e-02,  8.37197974e-02, -4.29389700e-02,
          4.55013290e-02, -3.90602089e-02],
        [ 9.40914825e-03,  1.21670708e-01,  6.48907349e-02,
          3.28224525e-03, -2.50251535e-02,  1.20714554e-04,
          8.81221592e-02, -6.38489798e-02],
        [ 2.95085814e-02,  6.52939826e-02,  8.16039462e-03,
         -6.26104698e-02, -1.62040964e-02, -8.30524042e-02,
          1.69737469e-02,  2.79735588e-02],
        [-4.53358404e-02, -2.23228615e-02, -6.19471893e-02,
          5.31434007e-02,  1.07010707e-01, -2.73960289e-02,
         -5.80873974e-02,  9.80136469e-02],
        [ 1.79467648e-02,  6.75790384e-02,  6.74972013e-02,
          4.49689254e-02, -4.86278459e-02, -3.58133242e-02,
         -4.75102328e-02, -8.05234313e-02]], dtype=float32),

In [17]:
print("query_shape = ", query.shape)
print("key_shape = ", key.shape)
print("value_shape = ", value.shape)
print()

query_shape =  (6, 8)
key_shape =  (6, 4)
value_shape =  (6, 4)



In [18]:
q_len, _ = query.shape


print("After reshaping... \n")

query1 = np.transpose(np.reshape(query, (q_len, num_heads, head_dim)), (1, 0, 2))
key1 = np.transpose(np.reshape(key, (q_len, n_kv_heads, head_dim)), (1, 0, 2))
value1 = np.transpose(np.reshape(value, (q_len, n_kv_heads, head_dim)), (1, 0, 2))

print("query_shape = ", query1.shape)
print("key_shape = ", key1.shape)
print("value_shape = ", value1.shape)
print()

After reshaping... 

query_shape =  (4, 6, 2)
key_shape =  (2, 6, 2)
value_shape =  (2, 6, 2)



### 4. Obtaining the rotary embeddings 

#### 4.1 Pre-computing the sin and cos values

In [19]:
base = 10000
max_seq_len = 2048
dim = head_dim


inv_freq = 1.0 / (base ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
t = np.arange(max_seq_len, dtype=np.float32)
freqs = np.outer(t, inv_freq)
# emb = np.concatenate((np.cos(freqs), np.sin(freqs)), axis=-1)
emb = np.concatenate((freqs,freqs), axis=-1)

cos, sin =  np.cos(emb[:seq_len]), np.sin(emb[:seq_len])

In [20]:
cos, sin

(array([[ 1.        ,  1.        ],
        [ 0.5403023 ,  0.5403023 ],
        [-0.4161468 , -0.4161468 ],
        [-0.9899925 , -0.9899925 ],
        [-0.6536436 , -0.6536436 ],
        [ 0.28366217,  0.28366217]], dtype=float32),
 array([[ 0.       ,  0.       ],
        [ 0.841471 ,  0.841471 ],
        [ 0.9092974,  0.9092974],
        [ 0.14112  ,  0.14112  ],
        [-0.7568025, -0.7568025],
        [-0.9589243, -0.9589243]], dtype=float32))

#### 4.2 Applying the rotations on the Q and K matrices 

In [21]:
# Q matrix rotation

unsqueeze_dim = 0

cos_exp = np.expand_dims(cos, axis=unsqueeze_dim)
sin_exp = np.expand_dims(sin, axis=unsqueeze_dim)


# Half rotation 
q1 = query1[..., :query1.shape[-1] // 2]
q2 = query1[..., query1.shape[-1] // 2:]
q_half_rot = np.concatenate((-q2, q1), axis=-1)


query_rotated = query1*cos_exp + q_half_rot*sin_exp


In [22]:
# K matrix rotation

unsqueeze_dim = 0

cos_exp = np.expand_dims(cos, axis=unsqueeze_dim)
sin_exp = np.expand_dims(sin, axis=unsqueeze_dim)


# Half rotation 
k1 = key1[..., :key1.shape[-1] // 2]
k2 = key1[..., key1.shape[-1] // 2:]
key_half_rot = np.concatenate((-k2, k1), axis=-1)


key_rotated = key1*cos_exp + key_half_rot*sin_exp


In [23]:
query_rotated, key_rotated

(array([[[-0.02395232,  0.00295548],
         [-0.011754  , -0.00722822],
         [-0.11455044, -0.04207717],
         [-0.03842756, -0.0604763 ],
         [ 0.01273949,  0.04890147],
         [ 0.069894  ,  0.00196003]],
 
        [[ 0.00991354, -0.10565892],
         [-0.08522043,  0.03833092],
         [-0.02998861,  0.05763908],
         [ 0.00075686,  0.06313549],
         [ 0.08071044,  0.01214494],
         [ 0.0622682 , -0.05196872]],
 
        [[-0.0385008 ,  0.00130411],
         [ 0.0813659 ,  0.04724776],
         [ 0.01030437, -0.02280554],
         [ 0.02776229,  0.07993453],
         [-0.09068024, -0.06307873],
         [-0.04813614,  0.03647154]],
 
        [[-0.05901711,  0.08173286],
         [ 0.05745251,  0.01718373],
         [ 0.02138595,  0.1066998 ],
         [-0.02075151, -0.02529828],
         [ 0.11214543, -0.02010531],
         [-0.09069273,  0.02271727]]], dtype=float32),
 array([[[ 0.00020003, -0.01606483],
         [-0.0133947 ,  0.02100287],
         [ 

In [24]:
value1

array([[[ 0.01107843, -0.02947902],
        [ 0.03758795,  0.04618904],
        [-0.03502538,  0.02122178],
        [ 0.00834949,  0.10100346],
        [-0.00625654, -0.14200555],
        [ 0.00622982, -0.00740614]],

       [[ 0.03121278, -0.04385417],
        [-0.03898475,  0.03285853],
        [ 0.01888887,  0.08686603],
        [-0.01351706,  0.06688891],
        [-0.01066491, -0.05208423],
        [ 0.04187129,  0.0693242 ]]], dtype=float32)

## 5. Sliding window attention
### (Grouped Query Attention)


#### 5.1 Repating the K and V values for the GQA


In [25]:
# Repeating the Value vector 

num_value_value_heads, seq_len, head_dim = value1.shape[0], value1.shape[1], value1.shape[2]

if n_rep > 1:

    value1 = np.broadcast_to(value1[:, np.newaxis, :, :], (num_value_value_heads, n_rep, seq_len, head_dim))
    value1 =  value1.reshape(num_value_value_heads * n_rep, seq_len, head_dim)

In [26]:
value1

array([[[ 0.01107843, -0.02947902],
        [ 0.03758795,  0.04618904],
        [-0.03502538,  0.02122178],
        [ 0.00834949,  0.10100346],
        [-0.00625654, -0.14200555],
        [ 0.00622982, -0.00740614]],

       [[ 0.01107843, -0.02947902],
        [ 0.03758795,  0.04618904],
        [-0.03502538,  0.02122178],
        [ 0.00834949,  0.10100346],
        [-0.00625654, -0.14200555],
        [ 0.00622982, -0.00740614]],

       [[ 0.03121278, -0.04385417],
        [-0.03898475,  0.03285853],
        [ 0.01888887,  0.08686603],
        [-0.01351706,  0.06688891],
        [-0.01066491, -0.05208423],
        [ 0.04187129,  0.0693242 ]],

       [[ 0.03121278, -0.04385417],
        [-0.03898475,  0.03285853],
        [ 0.01888887,  0.08686603],
        [-0.01351706,  0.06688891],
        [-0.01066491, -0.05208423],
        [ 0.04187129,  0.0693242 ]]], dtype=float32)

In [27]:
# Repeating the Key vector 

num_key_value_heads, seq_len, head_dim = key_rotated.shape[0], key_rotated.shape[1], key_rotated.shape[2]

if n_rep > 1:

    key_rotated = np.broadcast_to(key_rotated[:, np.newaxis, :, :], (num_key_value_heads, n_rep, seq_len, head_dim))
    key_rotated =  key_rotated.reshape(num_key_value_heads * n_rep, seq_len, head_dim)

In [28]:
key_rotated

array([[[ 0.00020003, -0.01606483],
        [-0.0133947 ,  0.02100287],
        [ 0.09941693, -0.01702986],
        [-0.03397126,  0.00789048],
        [-0.03232932, -0.08354992],
        [-0.01233742,  0.00147469]],

       [[ 0.00020003, -0.01606483],
        [-0.0133947 ,  0.02100287],
        [ 0.09941693, -0.01702986],
        [-0.03397126,  0.00789048],
        [-0.03232932, -0.08354992],
        [-0.01233742,  0.00147469]],

       [[-0.07288229,  0.1298013 ],
        [ 0.16702944,  0.05697173],
        [ 0.05529125,  0.02423418],
        [-0.02959396, -0.01407389],
        [-0.06547794, -0.03196203],
        [ 0.01601789, -0.02786594]],

       [[-0.07288229,  0.1298013 ],
        [ 0.16702944,  0.05697173],
        [ 0.05529125,  0.02423418],
        [-0.02959396, -0.01407389],
        [-0.06547794, -0.03196203],
        [ 0.01601789, -0.02786594]]], dtype=float32)

In [29]:
query_rotated.shape, key_rotated.shape, value1.shape

((4, 6, 2), (4, 6, 2), (4, 6, 2))

### Self attention (GQA + Sliding window attention)

#### The mask is passed as required by the sliding window attention

In [30]:
seq_length = src_tokenized_np.shape[0]
sliding_window_len = 3

sliding_window_mask = np.triu(np.ones((seq_length, seq_length)), k=1).astype('bool')

# print(sliding_window_mask)

for i in range(sliding_window_mask.shape[0]-1, -1, -1):

    li = i - sliding_window_len + 1

    if li > 0:

        sliding_window_mask[i][0:li] = True

attn_mask = sliding_window_mask

sliding_window_mask


array([[False,  True,  True,  True,  True,  True],
       [False, False,  True,  True,  True,  True],
       [False, False, False,  True,  True,  True],
       [ True, False, False, False,  True,  True],
       [ True,  True, False, False, False,  True],
       [ True,  True,  True, False, False, False]])

In [31]:
L, S = query.shape[-2], key.shape[-2]

scale_factor = 1 / np.sqrt(query.shape[-1])
attn_bias = np.zeros((L, S), dtype=query.dtype)

if attn_mask is not None:
    if attn_mask.dtype == np.bool_:
        attn_bias[attn_mask] = -np.inf
    else:
        attn_bias += attn_mask

    print("Attention bias = ")
    print(attn_bias)

Attention bias = 
[[  0. -inf -inf -inf -inf -inf]
 [  0.   0. -inf -inf -inf -inf]
 [  0.   0.   0. -inf -inf -inf]
 [-inf   0.   0.   0. -inf -inf]
 [-inf -inf   0.   0.   0. -inf]
 [-inf -inf -inf   0.   0.   0.]]


In [32]:
key_rotated_T = np.transpose(key_rotated, axes=(0, 2, 1))

attn_weight = query_rotated @ key_rotated_T * scale_factor
attn_weight += attn_bias

exp_attn_weight = np.exp(attn_weight)
sum_exp_attn_weight = np.sum(exp_attn_weight, axis=-1, keepdims=True)
softmax_attn_weight = exp_attn_weight / sum_exp_attn_weight

print("SoftMax (Scaled Dot Product of Q and K) = ")
softmax_attn_weight

SoftMax (Scaled Dot Product of Q and K) = 


array([[[1.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.50000954, 0.49999046, 0.        , 0.        , 0.        ,
         0.        ],
        [0.333778  , 0.3337777 , 0.33244425, 0.        , 0.        ,
         0.        ],
        [0.        , 0.33335105, 0.33311126, 0.33353773, 0.        ,
         0.        ],
        [0.        , 0.        , 0.33354586, 0.3334892 , 0.33296487,
         0.        ],
        [0.        , 0.        , 0.        , 0.33327696, 0.33326936,
         0.33345369]],

       [[1.        , 0.        , 0.        , 0.        , 0.        ,
         0.        ],
        [0.49977204, 0.500228  , 0.        , 0.        , 0.        ,
         0.        ],
        [0.33335233, 0.33365232, 0.3329953 , 0.        , 0.        ,
         0.        ],
        [0.        , 0.33345747, 0.33318454, 0.33335802, 0.        ,
         0.        ],
        [0.        , 0.        , 0.33419436, 0.3329604 , 0.33284527,
         0.        

In [33]:
attn_output = softmax_attn_weight @ value1

In [34]:
attn_output

# +- 0.005 difference due to numpy precision 

array([[[ 0.01107843, -0.02947902],
        [ 0.02433294,  0.00835429],
        [ 0.00459977,  0.01263248],
        [ 0.0036475 ,  0.05615484],
        [-0.01098131, -0.00652086],
        [ 0.00277493, -0.01613358]],

       [[ 0.01107843, -0.02947902],
        [ 0.02433924,  0.00837226],
        [ 0.00457104,  0.01265093],
        [ 0.00364744,  0.05614316],
        [-0.0110077 , -0.00654352],
        [ 0.00276977, -0.01620636]],

       [[ 0.03121278, -0.04385417],
        [-0.00398575, -0.00538879],
        [ 0.00369151,  0.02532087],
        [-0.01122158,  0.06218171],
        [-0.00179356,  0.0337827 ],
        [ 0.0058829 ,  0.02801968]],

       [[ 0.03121278, -0.04385417],
        [-0.00396374, -0.00541284],
        [ 0.00370368,  0.02522594],
        [-0.01119476,  0.06221688],
        [-0.00173921,  0.033982  ],
        [ 0.00587202,  0.02799441]]], dtype=float32)

### 6. Post attention 

In [35]:
Wo = state_dict["layers.0.self_attn.o_proj.weight"].numpy()

self_attn_op = np.transpose(attn_output, (1, 0, 2)).copy()
self_attn_op = self_attn_op.reshape(q_len, embed_dim)

sa_output = np.matmul(self_attn_op, Wo.T)

In [36]:
sa_output, sa_output.shape

(array([[ 8.4407721e-04, -8.3392626e-04,  1.4437779e-03,  1.2047621e-03,
          1.5769361e-03,  3.5440541e-04, -1.0549710e-03,  1.5465862e-03],
        [-8.3491480e-04, -4.5491473e-04, -2.7456848e-04,  7.6034143e-05,
         -4.5097884e-04,  3.2300162e-04,  2.5672549e-05,  1.1732047e-03],
        [-4.3571822e-04,  6.8573118e-04, -5.8262644e-04, -7.7831530e-04,
         -1.8956563e-04, -6.5275497e-04,  6.5544032e-04, -4.8372874e-04],
        [-2.0873179e-03,  1.9857292e-03, -1.5717335e-03, -2.0375180e-03,
         -1.4400799e-03, -1.8136515e-03,  1.4754506e-03, -1.6038014e-03],
        [ 7.1967638e-04,  4.8155471e-04, -8.0577575e-04, -7.8643003e-04,
         -2.6787560e-05, -1.6236477e-04,  8.8560878e-04, -1.2233036e-03],
        [ 8.8070869e-04,  5.9564245e-05, -7.7368139e-04, -5.8852532e-04,
          2.6537999e-04,  1.4871509e-04,  8.6985884e-04, -4.4311673e-04]],
       dtype=float32),
 (6, 8))

In [37]:
hidden_states = residual + sa_output

hidden_states


array([[-0.01964682,  0.00242219, -0.02129994,  0.01021334, -0.01415303,
        -0.01922511,  0.02481084,  0.00475256],
       [ 0.01475124,  0.00102926,  0.02839903, -0.00878431,  0.00844752,
         0.0157954 , -0.00870074, -0.00891756],
       [-0.01374281, -0.01569779,  0.01210462, -0.01155952,  0.01733334,
        -0.004907  , -0.00906742,  0.01945548],
       [ 0.00211111, -0.00137907,  0.00052862, -0.01373778, -0.00973283,
         0.01022216,  0.03489678,  0.01589755],
       [-0.00710247,  0.03046947, -0.01213971,  0.00657298,  0.01693613,
         0.01187226, -0.00034929, -0.03954912],
       [-0.01312414,  0.02944256, -0.00550178,  0.00541989,  0.00885686,
        -0.02842241, -0.01862311,  0.04228063]])

In [38]:
residual = hidden_states

### 7. LayerNorm 2

In [39]:
# Layer norm

variance_epsilon = 1e-05

wt = state_dict["layers.0.post_attention_layernorm.weight"].numpy()

dtype = hidden_states.dtype
hidden_states = hidden_states.astype(np.float32)
variance = np.mean(hidden_states**2, axis=-1, keepdims=True)
hidden_states = hidden_states * (1/np.sqrt(variance + variance_epsilon))

hidden_states = wt * hidden_states


## 8. Mixtral Feed forward network 

### 8.1 Logits from the MoE layer 



In [40]:
W_logit = state_dict["layers.0.block_sparse_moe.gate.weight"].numpy()

logits = np.matmul(hidden_states, W_logit.T)

logits

array([[-0.00319363, -0.04745055,  0.0325985 , -0.01575745],
       [-0.0099739 ,  0.06966699, -0.05062584,  0.00032955],
       [ 0.08649002, -0.01674007, -0.00187917, -0.02254649],
       [-0.01136106,  0.02795726,  0.0422231 , -0.04087061],
       [-0.0448232 ,  0.01951255, -0.0529776 ,  0.03426894],
       [ 0.02523708, -0.00186044,  0.07299722, -0.02028412]],
      dtype=float32)

In [41]:
exp_x = np.exp(logits - np.max(logits, axis=1, keepdims=True))
logits_norm = exp_x / np.sum(exp_x, axis=1, keepdims=True)
logits_norm

array([[0.2512144 , 0.24033889, 0.26036876, 0.24807794],
       [0.24670537, 0.26715678, 0.23687743, 0.24926043],
       [0.26924857, 0.24284053, 0.24647631, 0.24143457],
       [0.2459378 , 0.25580028, 0.25947565, 0.23878632],
       [0.24150895, 0.2575573 , 0.23954761, 0.2613861 ],
       [0.25140253, 0.2446816 , 0.2637009 , 0.24021494]], dtype=float32)

### 8.2 Selecting the top 2 experts per token from the 4 of them 

These can be changed (hyperparameters)

In [42]:
# routing_weights, selected_experts = torch.topk(logits_norm, num_experts_per_tok, dim=-1)
num_experts_per_tok = 2

selected_experts = np.argsort(logits_norm, axis=-1)[ :, -num_experts_per_tok:]
selected_experts = np.flip(selected_experts, axis=-1)
routing_weights = np.take_along_axis(logits_norm, selected_experts, axis=-1)

In [43]:
selected_experts, routing_weights

(array([[2, 0],
        [1, 3],
        [0, 2],
        [2, 1],
        [3, 1],
        [2, 0]]),
 array([[0.26036876, 0.2512144 ],
        [0.26715678, 0.24926043],
        [0.26924857, 0.24647631],
        [0.25947565, 0.25580028],
        [0.2613861 , 0.2575573 ],
        [0.2637009 , 0.25140253]], dtype=float32))

### 8.3 Normalising the weights for weighted sum

In [44]:
routing_weights /= np.sum(routing_weights, axis=-1, keepdims=True)

routing_weights

array([[0.5089471 , 0.4910529 ],
       [0.5173274 , 0.48267257],
       [0.5220779 , 0.47792205],
       [0.5035664 , 0.49643356],
       [0.503689  , 0.49631095],
       [0.51193774, 0.4880622 ]], dtype=float32)

In [45]:
# Intialising the torch vectors for final state

final_hidden_states = np.zeros(( seq_len, embed_dim), dtype=hidden_state.dtype)

final_hidden_states.shape

(6, 8)

In [46]:
expert_mask = np.eye(num_local_experts)[selected_experts].transpose(2, 1, 0)


In [47]:
expert_mask

array([[[0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1.]],

       [[0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 1., 0.]],

       [[1., 0., 0., 1., 0., 1.],
        [0., 0., 1., 0., 0., 0.]],

       [[0., 0., 0., 0., 1., 0.],
        [0., 1., 0., 0., 0., 0.]]])

### 8.4 Storing the W1, W2 and W3 weights for each of the experts 

To compute the outputs for the tokens assigned for specific feedforwad units 

In [48]:
exp_prefix = "layers.0.block_sparse_moe.experts." 

expert_weights = []

for i in range(num_local_experts):

    exp_wt_pref = exp_prefix + str(i)

    temp_wts = {"w1":None, "w2":None, "w3":None}

    for j in range(1,4):

        exp_wt_pref_new = exp_wt_pref + ".w" + str(j) + ".weight"

        temp_wts["w" + str(j)] = state_dict[exp_wt_pref_new].numpy()

    expert_weights.append(temp_wts)



len(expert_weights)

4

In [49]:
hidden_states_moe = hidden_state

### 8.5 Computing the feedforwad outputs for the **experts assigned for each token**

In [50]:
for expert_idx in range(num_local_experts):
    
    expert_layer_wts = expert_weights[expert_idx]
    idx, top_x = np.where(expert_mask[expert_idx])

    if top_x.shape[0] == 0:
        continue

    current_state = hidden_states_moe[top_x]


    W_up_proj = expert_layer_wts["w1"]  
    W_gate_proj = expert_layer_wts["w3"]
    W_down_proj = expert_layer_wts["w2"]

    up_proj = np.matmul(current_state, W_up_proj.T)
    gate_proj = np.matmul(current_state, W_gate_proj.T)

    print("UP PROJ = ", up_proj)
    print()

    print("GATE PROJ = ", gate_proj)
    print()

    temp_proj = up_proj * gate_proj


    # SilU 
    temp_proj = temp_proj / (1 + np.exp(-temp_proj))

    print("ACT = ", temp_proj)

    down_proj = np.matmul(temp_proj, W_down_proj.T)

    current_hidden_states = down_proj



    for i, x in enumerate(top_x):
        final_hidden_states[x] += current_hidden_states[i]



UP PROJ =  [[ 2.0779315e-03 -4.1524232e-05  5.3748214e-03  4.9661640e-03
  -7.7020541e-02 -7.4964218e-02 -1.0597602e-02  5.9146583e-02]
 [-2.5118727e-02 -2.0864358e-02 -9.6453838e-02  2.0347063e-03
   4.5176938e-02  4.0917356e-02 -6.4834498e-02 -7.2856932e-03]
 [-2.9358903e-02 -7.7371515e-02 -4.8153035e-02 -9.4482111e-04
  -1.3781639e-02 -6.1157711e-02 -7.1766920e-02 -3.9092816e-02]]

GATE PROJ =  [[-0.00769947 -0.08333034 -0.09006876  0.009973   -0.0639486  -0.12828517
   0.06382238  0.01461877]
 [ 0.02198876 -0.06551665 -0.06325372 -0.0090334   0.08448258  0.03772091
  -0.04802711  0.01377582]
 [-0.01560306  0.00607488 -0.04853232 -0.01146191  0.04907023 -0.00804746
  -0.19398391  0.0003019 ]]

ACT =  [[-7.9994243e-06  1.7301171e-06 -2.4199316e-04  2.4764393e-05
   2.4687427e-03  4.8315194e-03 -3.3806774e-04  4.3251214e-04]
 [-2.7608857e-04  6.8394857e-04  3.0598375e-03 -9.1900756e-06
   1.9119739e-03  7.7231554e-04  1.5593307e-03 -5.0180675e-05]
 [ 2.2909684e-04 -2.3495614e-04  1.16

In [51]:
final_hidden_states

array([[ 3.02542794e-05,  1.35043272e-04, -1.54088426e-04,
         5.34421670e-05,  8.43894668e-06, -1.77466209e-04,
        -1.46090271e-04, -8.37543193e-05],
       [-2.19882131e-05,  8.97436694e-05, -2.92263467e-05,
         8.74870675e-05, -6.94078117e-05, -4.85683049e-05,
        -8.66210903e-05,  7.54174980e-05],
       [ 1.81148906e-04,  1.70211293e-04, -1.39261247e-04,
        -8.31647048e-05, -1.85562851e-04,  1.98851703e-04,
        -4.84626544e-05,  2.80959830e-05],
       [-1.30659228e-05, -6.35682591e-05, -6.33936070e-05,
         1.08051478e-04, -1.01529673e-04, -1.31385939e-04,
        -5.35686850e-05, -7.56333829e-05],
       [ 2.67640462e-05, -6.88876753e-05,  4.22416051e-05,
         1.30987188e-04,  3.85967724e-06, -5.55036058e-06,
        -8.86259659e-05,  1.81700739e-06],
       [-6.00403946e-05, -7.89116166e-05, -1.21364224e-04,
        -1.32052272e-04,  2.00030801e-04,  1.27674823e-04,
         1.37436858e-04, -3.09652125e-04]], dtype=float32)

In [52]:
# Residual connection 

hidden_state = final_hidden_states + residual

In [53]:
hidden_state

array([[-0.01961656,  0.00255724, -0.02145403,  0.01026679, -0.01414459,
        -0.01940257,  0.02466475,  0.0046688 ],
       [ 0.01472925,  0.001119  ,  0.0283698 , -0.00869682,  0.00837811,
         0.01574683, -0.00878736, -0.00884214],
       [-0.01356167, -0.01552758,  0.01196536, -0.01164269,  0.01714777,
        -0.00470815, -0.00911589,  0.01948358],
       [ 0.00209804, -0.00144264,  0.00046522, -0.01362973, -0.00983436,
         0.01009078,  0.03484321,  0.01582192],
       [-0.0070757 ,  0.03040059, -0.01209747,  0.00670397,  0.01693999,
         0.01186671, -0.00043791, -0.0395473 ],
       [-0.01318418,  0.02936365, -0.00562315,  0.00528784,  0.00905689,
        -0.02829473, -0.01848567,  0.04197098]])