## Mixtral 8X7B 

Tiny Mixtral is used for checking the computations (Small model, same architercture)

In [1]:
from transformers.trainer_utils import set_seed
import torch
SEED = 6
set_seed(SEED)

In [2]:
from transformers import AutoConfig, MixtralConfig


mixtral_config = MixtralConfig.from_pretrained("NickyNicky/Mixtral-TinyMistral-8x248M-Instruct_oasst2_chatML_Intel_orca_dpo_pairs_DPO_V1", num_hidden_layers = 1, use_cache = False, hidden_size = 8, num_attention_heads = 4, 
                                           output_hidden_states=True,  num_key_value_heads = 2, past_key_values = True, intermediate_size = 8, sliding_window = 3, dropout_p = 0, 
                                           
                                           num_local_experts = 4, num_experts_per_tok = 2)

mixtral_config

MixtralConfig {
  "_name_or_path": "NickyNicky/LocutusqueXFelladrin-TinyMistral248M-Instruct_oasst2_chatML_V4",
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 8,
  "initializer_range": 0.02,
  "intermediate_size": 8,
  "max_position_embeddings": 32768,
  "model_type": "mixtral",
  "num_attention_heads": 4,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 1,
  "num_key_value_heads": 2,
  "num_local_experts": 4,
  "output_hidden_states": true,
  "output_router_logits": false,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "router_aux_loss_coef": 0.001,
  "sliding_window": 3,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": false,
  "vocab_size": 32005
}

In [3]:
from transformers import AutoModel

tinymixtral = AutoModel.from_config(mixtral_config)

In [4]:
from transformers import AutoTokenizer

src_sent = "hi how are you doing"

mistal_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

In [5]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [6]:
tokenized_src_dict = mistal_tokenizer.encode_plus(src_sent, return_tensors='pt')
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [7]:
src_tokenized = tokenized_src_dict["input_ids"]
src_tokenized

tensor([[    1, 12014,   910,   460,   368,  2548]])

In [8]:
mistal_tokenizer.decode(*src_tokenized)

'<s> hi how are you doing'

In [9]:
seq_length = src_tokenized.shape[1]
sliding_window_len = 3

sliding_window_mask = 1 - (torch.triu(torch.ones(seq_length, seq_length), diagonal=1))


for i in range(sliding_window_mask.shape[0]-1, -1, -1):

    li = i - sliding_window_len + 1


    if li > 0:

        sliding_window_mask[i][0:li] = 0

sliding_window_mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 0.],
        [0., 0., 0., 1., 1., 1.]])

In [10]:
sliding_window_mask = sliding_window_mask.unsqueeze(0)
sliding_window_mask = sliding_window_mask.unsqueeze(0)
sliding_window_mask.shape, sliding_window_mask

(torch.Size([1, 1, 6, 6]),
 tensor([[[[1., 0., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0., 0.],
           [1., 1., 1., 0., 0., 0.],
           [0., 1., 1., 1., 0., 0.],
           [0., 0., 1., 1., 1., 0.],
           [0., 0., 0., 1., 1., 1.]]]]))

In [11]:
from pprint import pprint 

tokenized_src_dict["attention_mask"] = sliding_window_mask

pprint(tokenized_src_dict)


{'attention_mask': tensor([[[[1., 0., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0.],
          [0., 1., 1., 1., 0., 0.],
          [0., 0., 1., 1., 1., 0.],
          [0., 0., 0., 1., 1., 1.]]]]),
 'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]])}


### **SLIDING WINDOW ATTENTION** and **GROUPED-QUERY ATTENTION**

In [12]:
# Hyperparameters of the model we loaded

# Hidden size
embed_dim = 8

num_heads = 4

# For grouped query attention 
# No.of query heds 
n_heads_q = num_heads

# No.of key and values heads 
n_kv_heads = 2

# Sequence length of the input passed 
seq_len = 6

# Dimension per head
head_dim = embed_dim // num_heads

# No.of time the K and V matrices needs to be repeated
n_rep = n_heads_q//n_kv_heads
print("n_rep = ", n_rep)


# MOE hyperparameters 
num_local_experts = 4
num_experts_per_tok = 2

n_rep =  2


In [13]:
state_dict = tinymixtral.state_dict()

### Function for getting the word embeddings

In [14]:
import torch

def look_up_table(sentence, vocab_embeds, embedding):

    for i in range(sentence.size(0)):
        for j in range(sentence.size(1)):
            
            # Get the index for the current word token index in the sequence
            word_index = sentence[i, j].item()

            if word_index < 0 or word_index >= vocab_embeds.size(0):
                raise ValueError(f"Invalid word index: {word_index}")

            # Lookup the corresponding embedding vector for the word
            embedding[i, j, :] = vocab_embeds[word_index, :]

            print(f"Word index: {word_index}, Embedding: {vocab_embeds[word_index, :]}")
    print()

    return embedding

In [15]:
def get_embedding_outputs(src_tokens, state_dict, d_model):

    src_vocab_embeds = state_dict["embed_tokens.weight"]

    src_embedding = torch.zeros(src_tokens.size(0), src_tokens.size(1), d_model)
    print("Source sentence embedding")
    src_embedding =  look_up_table(src_tokens, src_vocab_embeds, src_embedding)
    print(src_embedding.shape)


    print("Source embeddings : \n")
    print(src_embedding)

    return src_embedding


input_embeddings = get_embedding_outputs(src_tokenized, state_dict, d_model = embed_dim)

Source sentence embedding
Word index: 1, Embedding: tensor([-0.0205,  0.0033, -0.0227,  0.0090, -0.0157, -0.0196,  0.0259,  0.0032])
Word index: 12014, Embedding: tensor([ 0.0156,  0.0015,  0.0287, -0.0089,  0.0089,  0.0155, -0.0087, -0.0101])
Word index: 910, Embedding: tensor([-0.0133, -0.0164,  0.0127, -0.0108,  0.0175, -0.0043, -0.0097,  0.0199])
Word index: 460, Embedding: tensor([ 0.0042, -0.0034,  0.0021, -0.0117, -0.0083,  0.0120,  0.0334,  0.0175])
Word index: 368, Embedding: tensor([-0.0078,  0.0300, -0.0113,  0.0074,  0.0170,  0.0120, -0.0012, -0.0383])
Word index: 2548, Embedding: tensor([-0.0140,  0.0294, -0.0047,  0.0060,  0.0086, -0.0286, -0.0195,  0.0427])

torch.Size([1, 6, 8])
Source embeddings : 

tensor([[[-0.0205,  0.0033, -0.0227,  0.0090, -0.0157, -0.0196,  0.0259,
           0.0032],
         [ 0.0156,  0.0015,  0.0287, -0.0089,  0.0089,  0.0155, -0.0087,
          -0.0101],
         [-0.0133, -0.0164,  0.0127, -0.0108,  0.0175, -0.0043, -0.0097,
           0.01

In [16]:
residual = input_embeddings

### Pre-normlization in the transformer block

In [17]:
def apply_layernorm(hidden_states, wt, variance_epsilon = 1e-06):

    dtype =  hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    op = wt * hidden_states
    return op.to(dtype)

In [18]:
hidden_state = apply_layernorm(input_embeddings, state_dict["layers.0.input_layernorm.weight"], variance_epsilon = 1e-06)

hidden_state


tensor([[[-1.1963,  0.1901, -1.3278,  0.5259, -0.9184, -1.1431,  1.5101,
           0.1872],
         [ 1.0852,  0.1033,  1.9964, -0.6169,  0.6195,  1.0772, -0.6076,
          -0.7026],
         [-0.9566, -1.1778,  0.9121, -0.7751,  1.2597, -0.3058, -0.6990,
           1.4334],
         [ 0.2787, -0.2234,  0.1394, -0.7768, -0.5505,  0.7990,  2.2188,
           1.1619],
         [-0.4003,  1.5348, -0.5801,  0.3767,  0.8682,  0.6159, -0.0632,
          -1.9615],
         [-0.6097,  1.2791, -0.2058,  0.2616,  0.3740, -1.2438, -0.8486,
           1.8599]]])

In [19]:
state_dict.keys()

odict_keys(['embed_tokens.weight', 'layers.0.self_attn.q_proj.weight', 'layers.0.self_attn.k_proj.weight', 'layers.0.self_attn.v_proj.weight', 'layers.0.self_attn.o_proj.weight', 'layers.0.block_sparse_moe.gate.weight', 'layers.0.block_sparse_moe.experts.0.w1.weight', 'layers.0.block_sparse_moe.experts.0.w2.weight', 'layers.0.block_sparse_moe.experts.0.w3.weight', 'layers.0.block_sparse_moe.experts.1.w1.weight', 'layers.0.block_sparse_moe.experts.1.w2.weight', 'layers.0.block_sparse_moe.experts.1.w3.weight', 'layers.0.block_sparse_moe.experts.2.w1.weight', 'layers.0.block_sparse_moe.experts.2.w2.weight', 'layers.0.block_sparse_moe.experts.2.w3.weight', 'layers.0.block_sparse_moe.experts.3.w1.weight', 'layers.0.block_sparse_moe.experts.3.w2.weight', 'layers.0.block_sparse_moe.experts.3.w3.weight', 'layers.0.input_layernorm.weight', 'layers.0.post_attention_layernorm.weight', 'norm.weight'])

### Function to get the Q, K and V vectors from input embeddings 

In [20]:
def get_qkv(hidden_state ,Wq, Wk, Wv):


    q_matmul = hidden_state@Wq.T
    k_matmul = hidden_state@Wk.T
    v_matmul = hidden_state@Wv.T

    return q_matmul, k_matmul, v_matmul
    

In [21]:
Wq = state_dict["layers.0.self_attn.q_proj.weight"]
Wk = state_dict["layers.0.self_attn.k_proj.weight"]
Wv = state_dict["layers.0.self_attn.v_proj.weight"]
query, key, value = get_qkv(hidden_state ,Wq, Wk, Wv)

In [22]:
query, key, value

(tensor([[[-2.4317e-02,  3.0005e-03,  1.0064e-02, -1.0727e-01, -3.9087e-02,
            1.3240e-03, -5.9916e-02,  8.2977e-02],
          [-1.2701e-02,  6.1144e-03, -1.4088e-02,  9.4415e-02,  8.5527e-02,
           -4.3866e-02,  4.6483e-02, -3.9903e-02],
          [ 9.6255e-03,  1.2447e-01,  6.6383e-02,  3.3577e-03, -2.5601e-02,
            1.2348e-04,  9.0148e-02, -6.5317e-02],
          [ 3.0088e-02,  6.6576e-02,  8.3207e-03, -6.3840e-02, -1.6522e-02,
           -8.4684e-02,  1.7307e-02,  2.8523e-02],
          [-4.5867e-02, -2.2584e-02, -6.2673e-02,  5.3766e-02,  1.0826e-01,
           -2.7717e-02, -5.8768e-02,  9.9162e-02],
          [ 1.8099e-02,  6.8153e-02,  6.8070e-02,  4.5351e-02, -4.9041e-02,
           -3.6117e-02, -4.7914e-02, -8.1207e-02]]]),
 tensor([[[ 0.0002, -0.0163, -0.0740,  0.1318],
          [ 0.0107,  0.0231,  0.1412, -0.1121],
          [-0.0582, -0.0852, -0.0010, -0.0617],
          [ 0.0354, -0.0031,  0.0278,  0.0185],
          [ 0.0854,  0.0305,  0.0678, -0.02

### Rotary postional embeddings fucntions

In [23]:
def get_sin_cos(dim, seq_len, max_seq_len, base = 10000):

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

    t = torch.arange(max_seq_len, dtype=torch.int64).type_as(inv_freq)

    freqs = torch.outer(t, inv_freq)
    
    # Uses a different permutation in order to obtain the same calculation
    emb = torch.cat((freqs, freqs), dim=-1)    

    return  emb.cos()[:seq_len], emb.sin()[:seq_len]


def rotate_half(x):

    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]

    # print(x1.shape, x2.shape)
    # x1 = x[ : , : x.shape[-1] // 2]
    # x2 = x[ : , x.shape[-1] // 2 :]

    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # print("HALF ROT SHAPES = ")
    # print((rotate_half(q).shape,   sin.shape))
    # print((rotate_half(k).shape , sin.shape))

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

In [24]:
query.shape, key.shape, value.shape

(torch.Size([1, 6, 8]), torch.Size([1, 6, 4]), torch.Size([1, 6, 4]))

In [25]:
bsz, q_len, _ = query.shape

query = query.view(bsz, q_len, n_heads_q, head_dim).transpose(1, 2)
key = key.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)
value = value.view(bsz, q_len, n_kv_heads, head_dim).transpose(1, 2)

In [26]:
query.shape, key.shape, value.shape

(torch.Size([1, 4, 6, 2]), torch.Size([1, 2, 6, 2]), torch.Size([1, 2, 6, 2]))

In [27]:
def get_Rope(query, key, head_dim, seq_len,  num_heads):

    cos, sin = get_sin_cos(dim = head_dim, seq_len = seq_len, max_seq_len = 2048, base = 10000)

    cos = cos.unsqueeze(0)
    sin = sin.unsqueeze(0)

    q_rotated, k_rotated = apply_rotary_pos_emb(query, key, cos, sin, unsqueeze_dim=1)

    return q_rotated, k_rotated
    

In [28]:
# Rotating the query and key before self attention

query_rotated, key_rotated = get_Rope(query, key, head_dim, seq_len,  num_heads)

In [29]:
query_rotated, key_rotated

(tensor([[[[-0.0243,  0.0030],
           [-0.0120, -0.0074],
           [-0.1172, -0.0430],
           [-0.0392, -0.0617],
           [ 0.0129,  0.0495],
           [ 0.0705,  0.0020]],
 
          [[ 0.0101, -0.1073],
           [-0.0871,  0.0392],
           [-0.0307,  0.0590],
           [ 0.0008,  0.0644],
           [ 0.0817,  0.0123],
           [ 0.0628, -0.0524]],
 
          [[-0.0391,  0.0013],
           [ 0.0831,  0.0483],
           [ 0.0105, -0.0233],
           [ 0.0283,  0.0815],
           [-0.0917, -0.0638],
           [-0.0485,  0.0368]],
 
          [[-0.0599,  0.0830],
           [ 0.0587,  0.0176],
           [ 0.0219,  0.1092],
           [-0.0212, -0.0258],
           [ 0.1135, -0.0203],
           [-0.0915,  0.0229]]]]),
 tensor([[[[ 0.0002, -0.0163],
           [-0.0137,  0.0215],
           [ 0.1017, -0.0174],
           [-0.0346,  0.0080],
           [-0.0327, -0.0845],
           [-0.0124,  0.0015]],
 
          [[-0.0740,  0.1318],
           [ 0.1706,  0

In [30]:
value

tensor([[[[ 0.0112, -0.0299],
          [ 0.0384,  0.0472],
          [-0.0358,  0.0217],
          [ 0.0085,  0.1030],
          [-0.0063, -0.1437],
          [ 0.0063, -0.0075]],

         [[ 0.0317, -0.0445],
          [-0.0398,  0.0336],
          [ 0.0193,  0.0889],
          [-0.0138,  0.0682],
          [-0.0108, -0.0527],
          [ 0.0422,  0.0699]]]])

### Fucntion to repeat the K and V matrices for GQA (Grouped Query Attention)

In [31]:
def repeat_kv(x, n_rep):

    bsz, num_key_value_heads, seq_len, head_dim = x.shape

    if n_rep == 1:
        return x
    
    x = x[:, :, None, :, :].expand(bsz, num_key_value_heads, n_rep, seq_len, head_dim)
    return x.reshape(bsz, num_key_value_heads * n_rep, seq_len, head_dim)


In [32]:
key_rotated = repeat_kv(key_rotated, n_rep)
value = repeat_kv(value, n_rep)

In [33]:
query.shape, key_rotated.shape, value.shape

(torch.Size([1, 4, 6, 2]), torch.Size([1, 4, 6, 2]), torch.Size([1, 4, 6, 2]))

### Self attention (GQA + Sliding window attention)

#### The mask is passed as required by the sliding window attention

In [34]:
import math

def self_attention_rope(query, key, value, attn_mask = None, scale = None, is_causal=False):

    L, S = query.size(-2), key.size(-2)

    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)

    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    
    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask



    # (bsz, num_heads, tgt_len, head_dim) @ (bsz, num_heads, head_dim, tgt_len) -> (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias

    # (bsz, num_heads, tgt_len, tgt_len) 
    attn_weight = torch.softmax(attn_weight, dim=-1)


    sum_last_dim = attn_weight.sum(dim=-1)
    tolerance = 1e-6  
    assert torch.allclose(sum_last_dim, torch.ones_like(sum_last_dim), atol=tolerance), "Attention weights sum is not approximately equal to 1"


    # # (bsz, num_heads, tgt_len, tgt_len) @ (bsz, num_heads, tgt_len, head_dim) -> (bsz, num_heads, tgt_len, head_dim) 
    attn_output = attn_weight @ value

    print("ATTEN OUTPUT = ", attn_output)

    return attn_output


### Creating the attention mask for SWA

In [35]:
seq_length = src_tokenized.shape[1]
sliding_window_len = 3

sliding_window_mask =  (1 - torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

# print(sliding_window_mask)

for i in range(sliding_window_mask.shape[0]-1, -1, -1):

    li = i - sliding_window_len + 1

    if li > 0:

        sliding_window_mask[i][0:li] = False

sliding_window_mask

tensor([[ True, False, False, False, False, False],
        [ True,  True, False, False, False, False],
        [ True,  True,  True, False, False, False],
        [False,  True,  True,  True, False, False],
        [False, False,  True,  True,  True, False],
        [False, False, False,  True,  True,  True]])

In [36]:
self_attn_op = self_attention_rope(query_rotated, key_rotated, value, attn_mask = sliding_window_mask, is_causal=False)

ATTEN OUTPUT =  tensor([[[[ 0.0112, -0.0299],
          [ 0.0248,  0.0086],
          [ 0.0047,  0.0130],
          [ 0.0037,  0.0573],
          [-0.0112, -0.0062],
          [ 0.0028, -0.0160]],

         [[ 0.0112, -0.0299],
          [ 0.0248,  0.0087],
          [ 0.0047,  0.0130],
          [ 0.0037,  0.0573],
          [-0.0113, -0.0062],
          [ 0.0028, -0.0162]],

         [[ 0.0317, -0.0445],
          [-0.0043, -0.0052],
          [ 0.0037,  0.0260],
          [-0.0115,  0.0635],
          [-0.0018,  0.0346],
          [ 0.0059,  0.0284]],

         [[ 0.0317, -0.0445],
          [-0.0042, -0.0053],
          [ 0.0037,  0.0258],
          [-0.0114,  0.0636],
          [-0.0017,  0.0350],
          [ 0.0058,  0.0284]]]])


In [37]:
self_attn_op.shape

torch.Size([1, 4, 6, 2])

In [38]:
def out_proj_self_attn(self_attn_op, W, embed_dim):

    self_attn_op = self_attn_op.transpose(1, 2).contiguous()
    self_attn_op = self_attn_op.reshape(bsz, q_len, embed_dim)
    return self_attn_op@W.T

In [39]:
Wo = state_dict["layers.0.self_attn.o_proj.weight"]
print(Wo.shape)

sa_output = out_proj_self_attn(self_attn_op, Wo, embed_dim)

torch.Size([8, 8])


In [40]:
sa_output

# check and verified +- 0.002 diff

tensor([[[ 8.5693e-04, -8.4662e-04,  1.4658e-03,  1.2231e-03,  1.6009e-03,
           3.5980e-04, -1.0710e-03,  1.5701e-03],
         [-8.5480e-04, -4.6229e-04, -2.8992e-04,  7.1789e-05, -4.6940e-04,
           3.3089e-04,  3.2598e-05,  1.1919e-03],
         [-4.4477e-04,  7.0532e-04, -5.9762e-04, -7.9699e-04, -1.9645e-04,
          -6.6960e-04,  6.7209e-04, -4.9457e-04],
         [-2.1304e-03,  2.0274e-03, -1.6052e-03, -2.0819e-03, -1.4704e-03,
          -1.8510e-03,  1.5074e-03, -1.6401e-03],
         [ 7.1337e-04,  5.0488e-04, -8.2343e-04, -8.1675e-04, -3.3030e-05,
          -1.8039e-04,  9.0701e-04, -1.2630e-03],
         [ 8.8190e-04,  6.5856e-05, -7.8390e-04, -5.9789e-04,  2.5963e-04,
           1.4621e-04,  8.8148e-04, -4.5098e-04]]])

In [41]:
# Residual connection 

hidden_state = sa_output + residual

In [42]:
hidden_state = apply_layernorm(hidden_state, state_dict["layers.0.post_attention_layernorm.weight"], variance_epsilon = 1e-06)

hidden_state

tensor([[[-1.1940,  0.1465, -1.2939,  0.6222, -0.8592, -1.1688,  1.5078,
           0.2904],
         [ 1.0467,  0.0726,  2.0168, -0.6245,  0.5989,  1.1229, -0.6177,
          -0.6323],
         [-1.0018, -1.1421,  0.8807, -0.8434,  1.2622, -0.3587, -0.6593,
           1.4165],
         [ 0.1351, -0.0874,  0.0323, -0.9003, -0.6377,  0.6653,  2.2816,
           1.0361],
         [-0.3569,  1.5310, -0.6104,  0.3285,  0.8500,  0.5952, -0.0165,
          -1.9877],
         [-0.5778,  1.2966, -0.2427,  0.2382,  0.3897, -1.2515, -0.8194,
           1.8612]]])

### Mixtral Feed forward network 

Generic structure of an expert

In [43]:
# Only when 'pretraining_tp == 1'

import torch.nn.functional as F
import torch.nn as nn


def mixtral_mlp(x, W1, W2, W3):
    
    w1_proj =  x@W1.T
    w3_proj =  x@W3.T

    print("W1 PROJ = ", w1_proj)
    print()

    print("W3 PROJ = ", w3_proj)
    print()

    temp_proj = w1_proj * w3_proj

    silu = nn.SiLU()

    # sigm = nn.Sigmoid()

    # sigm_proj = sigm(temp_proj)
    # sigm_proj = torch.special.expit(temp_proj)

    # temp_proj = sigm_proj*temp_proj

    temp_proj = silu(temp_proj)
    print(temp_proj.dtype, x.dtype)

    print("ACT = ",temp_proj)

    w_final_proj = temp_proj@W2.T

    print("W2 proj = ", w_final_proj)

    return w_final_proj

## Mixtrue of experts (MoE)

### Function to get the logits from gate for all the tokens in the hidden state

For every token in the sequence 

In [44]:
def get_logits(hidden_states, W_logits):

    hidden_dim = hidden_states.shape[2]

    hidden_states = hidden_states.view(-1, hidden_dim)

    logits = hidden_states@W_logits.T

    return logits 
  


W_logit = state_dict["layers.0.block_sparse_moe.gate.weight"]
logits = get_logits(hidden_state, W_logit)

logits


tensor([[-0.0032, -0.0482,  0.0332, -0.0160],
        [-0.0102,  0.0712, -0.0517,  0.0003],
        [ 0.0885, -0.0172, -0.0019, -0.0231],
        [-0.0117,  0.0283,  0.0430, -0.0417],
        [-0.0454,  0.0197, -0.0536,  0.0346],
        [ 0.0254, -0.0019,  0.0736, -0.0205]])

#### Normalised logits 

In [45]:
logits_norm = torch.softmax(logits, dim=1)
logits_norm

tensor([[0.2512, 0.2402, 0.2605, 0.2480],
        [0.2466, 0.2675, 0.2366, 0.2492],
        [0.2697, 0.2427, 0.2464, 0.2412],
        [0.2458, 0.2559, 0.2597, 0.2386],
        [0.2414, 0.2576, 0.2394, 0.2615],
        [0.2514, 0.2446, 0.2638, 0.2401]])

### Selecting the top 2 experts per token from the 4 of them 

These can be changed (hyperparameters)

In [46]:
routing_weights, selected_experts = torch.topk(logits_norm, num_experts_per_tok, dim=-1)

In [47]:
selected_experts, routing_weights

(tensor([[2, 0],
         [1, 3],
         [0, 2],
         [2, 1],
         [3, 1],
         [2, 0]]),
 tensor([[0.2605, 0.2512],
         [0.2675, 0.2492],
         [0.2697, 0.2464],
         [0.2597, 0.2559],
         [0.2615, 0.2576],
         [0.2638, 0.2514]]))

#### Normalising the weights for weighted sum

In [48]:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights

tensor([[0.5091, 0.4909],
        [0.5177, 0.4823],
        [0.5226, 0.4774],
        [0.5037, 0.4963],
        [0.5037, 0.4963],
        [0.5121, 0.4879]])

In [49]:
# Intialising the torch vectors for final state
final_hidden_states = torch.zeros(
            (bsz * seq_len, embed_dim), dtype=hidden_state.dtype
        )

In [50]:
final_hidden_states.shape

torch.Size([6, 8])

In [51]:
# One hot encoded vectors for each of the experts (for all 4 of them)

expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=num_local_experts).permute(2, 1, 0)

In [52]:
expert_mask

tensor([[[0, 0, 1, 0, 0, 0],
         [1, 0, 0, 0, 0, 1]],

        [[0, 1, 0, 0, 0, 0],
         [0, 0, 0, 1, 1, 0]],

        [[1, 0, 0, 1, 0, 1],
         [0, 0, 1, 0, 0, 0]],

        [[0, 0, 0, 0, 1, 0],
         [0, 1, 0, 0, 0, 0]]])

In [53]:
state_dict.keys()

odict_keys(['embed_tokens.weight', 'layers.0.self_attn.q_proj.weight', 'layers.0.self_attn.k_proj.weight', 'layers.0.self_attn.v_proj.weight', 'layers.0.self_attn.o_proj.weight', 'layers.0.block_sparse_moe.gate.weight', 'layers.0.block_sparse_moe.experts.0.w1.weight', 'layers.0.block_sparse_moe.experts.0.w2.weight', 'layers.0.block_sparse_moe.experts.0.w3.weight', 'layers.0.block_sparse_moe.experts.1.w1.weight', 'layers.0.block_sparse_moe.experts.1.w2.weight', 'layers.0.block_sparse_moe.experts.1.w3.weight', 'layers.0.block_sparse_moe.experts.2.w1.weight', 'layers.0.block_sparse_moe.experts.2.w2.weight', 'layers.0.block_sparse_moe.experts.2.w3.weight', 'layers.0.block_sparse_moe.experts.3.w1.weight', 'layers.0.block_sparse_moe.experts.3.w2.weight', 'layers.0.block_sparse_moe.experts.3.w3.weight', 'layers.0.input_layernorm.weight', 'layers.0.post_attention_layernorm.weight', 'norm.weight'])

### Storing the W1, W2 and W3 weights for each of the experts 

To compute the outputs for the tokens assigned for specific feedforwad units 

In [54]:
# 1 - w1, w2, w3
# 2 - w1, w2, w3
# 3 - w1, w2, w3
# 4 - w1, w2, w3

In [55]:

exp_prefix = "layers.0.block_sparse_moe.experts." 

expert_weights = []

for i in range(num_local_experts):

    exp_wt_pref = exp_prefix + str(i)

    temp_wts = {"w1":None, "w2":None, "w3":None}

    for j in range(1,4):

        exp_wt_pref_new = exp_wt_pref + ".w" + str(j) + ".weight"

        temp_wts["w" + str(j)] = state_dict[exp_wt_pref_new]

    expert_weights.append(temp_wts)



len(expert_weights)

4

In [56]:
hidden_states_moe = hidden_state.view(-1, embed_dim)

## Computing the feedforwad outputs for the **experts assigned for each token**

In [57]:
for expert_idx in range(num_local_experts):

    expert_layer_wts = expert_weights[expert_idx]

    idx, top_x = torch.where(expert_mask[expert_idx])

    if top_x.shape[0] == 0:
        continue

    # in torch it is faster to index using lists than torch tensors
    top_x_list = top_x.tolist()
    idx_list = idx.tolist()

    # Index the correct hidden states and compute the expert hidden state for
    # the current expert. We need to make sure to multiply the output hidden
    # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
    current_state = hidden_states_moe[None, top_x_list].reshape(-1, embed_dim)

    current_hidden_states = mixtral_mlp(current_state, expert_layer_wts["w1"], expert_layer_wts["w2"], expert_layer_wts["w3"]) * routing_weights[top_x_list, idx_list, None]

    # However `index_add_` only support torch tensors for indexing so we'll use
    # the `top_x` tensor here.
    final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_state.dtype))


final_hidden_states = final_hidden_states.reshape(bsz, seq_len, embed_dim)


W1 PROJ =  tensor([[ 0.0018, -0.0021,  0.0020,  0.0039, -0.0791, -0.0755, -0.0111,  0.0611],
        [-0.0258, -0.0187, -0.0983,  0.0030,  0.0440,  0.0357, -0.0696, -0.0077],
        [-0.0298, -0.0790, -0.0499, -0.0019, -0.0149, -0.0616, -0.0725, -0.0390]])

W3 PROJ =  tensor([[-0.0082, -0.0883, -0.0946,  0.0140, -0.0667, -0.1311,  0.0647,  0.0161],
        [ 0.0225, -0.0679, -0.0663, -0.0125,  0.0879,  0.0320, -0.0503,  0.0111],
        [-0.0190,  0.0059, -0.0476, -0.0095,  0.0509, -0.0089, -0.1948,  0.0007]])

torch.float32 torch.float32
ACT =  tensor([[-7.2249e-06,  9.2548e-05, -9.2995e-05,  2.7197e-05,  2.6439e-03,
          4.9699e-03, -3.5835e-04,  4.9337e-04],
        [-2.9021e-04,  6.3659e-04,  3.2708e-03, -1.8742e-05,  1.9381e-03,
          5.7198e-04,  1.7519e-03, -4.2701e-05],
        [ 2.8234e-04, -2.3260e-04,  1.1896e-03,  9.1454e-06, -3.7784e-04,
          2.7400e-04,  7.1150e-03, -1.4499e-05]])
W2 proj =  tensor([[ 1.2205e-04,  1.4340e-04, -6.8711e-05, -1.5961e-04, -5.37

In [58]:
final_hidden_states

# Small perturbation due to SiLU floating point precision

tensor([[[ 1.0978e-05,  6.8431e-05, -7.7291e-05,  3.8363e-05,  9.4358e-06,
          -9.8680e-05, -7.9507e-05, -4.6135e-05],
         [-1.4491e-05,  3.9965e-05, -1.7774e-05,  4.8831e-05, -3.3206e-05,
          -2.8912e-05, -5.6277e-05,  4.2447e-05],
         [ 9.2201e-05,  9.3103e-05, -6.7900e-05, -4.8041e-05, -8.7071e-05,
           1.0794e-04, -2.2923e-05,  1.5988e-05],
         [-1.0395e-05, -2.9303e-05, -3.3655e-05,  6.7940e-05, -6.5514e-05,
          -7.1542e-05, -3.9455e-05, -3.6377e-05],
         [ 1.0742e-05, -3.4113e-05,  1.7477e-05,  6.5124e-05,  2.7828e-06,
          -5.6000e-06, -4.6424e-05,  2.2704e-06],
         [-2.7409e-05, -4.4076e-05, -6.3579e-05, -6.5281e-05,  9.7286e-05,
           6.2610e-05,  6.7082e-05, -1.5571e-04]]])

In [59]:
# Residual connection 

hidden_state = final_hidden_states + residual

In [60]:
hidden_state

# +- 0.01 floating point error due to the SiLU activation function precision 

tensor([[[-0.0205,  0.0033, -0.0228,  0.0090, -0.0157, -0.0197,  0.0258,
           0.0032],
         [ 0.0156,  0.0015,  0.0287, -0.0088,  0.0089,  0.0154, -0.0088,
          -0.0100],
         [-0.0132, -0.0163,  0.0126, -0.0108,  0.0174, -0.0041, -0.0097,
           0.0200],
         [ 0.0042, -0.0034,  0.0021, -0.0116, -0.0084,  0.0120,  0.0334,
           0.0175],
         [-0.0078,  0.0300, -0.0113,  0.0074,  0.0170,  0.0120, -0.0013,
          -0.0383],
         [-0.0140,  0.0293, -0.0048,  0.0059,  0.0087, -0.0285, -0.0194,
           0.0426]]])