In [1]:
from transformers.trainer_utils import set_seed
import torch

# Fixed seed for the run. The model below is built with `from_config`
# (no checkpoint weights are loaded in this notebook), so its weights depend
# on the RNG state.
# NOTE(review): presumably `set_seed` seeds Python, NumPy and torch together —
# confirm against the transformers documentation.
SEED = 6
set_seed(SEED)

In [2]:
from transformers import AutoConfig, MixtralConfig


# Start from a published Mixtral config and shrink it to a toy size
# (1 layer, hidden size 8, 4 experts) so the forward pass is cheap to run and
# every intermediate tensor is small enough to inspect by eye.
#
# Only real MixtralConfig fields are passed. The previous `past_key_values`
# and `dropout_p` keyword arguments are not config attributes, so they never
# reached the config (neither appears in the printed config below).
# `use_cache=False` already disables the KV cache, and `attention_dropout` is
# the actual dropout field.
mixtral_config = MixtralConfig.from_pretrained(
    "NickyNicky/Mixtral-TinyMistral-8x248M-Instruct_oasst2_chatML_Intel_orca_dpo_pairs_DPO_V1",
    # Model size
    num_hidden_layers=1,
    hidden_size=8,
    intermediate_size=8,
    # Attention: 4 query heads sharing 2 key/value heads, window of 3 tokens
    num_attention_heads=4,
    num_key_value_heads=2,
    sliding_window=3,
    attention_dropout=0.0,
    # Mixture of experts: 4 experts, 2 routed per token
    num_local_experts=4,
    num_experts_per_tok=2,
    # Runtime behaviour
    use_cache=False,
    output_hidden_states=True,
)


mixtral_config

MixtralConfig {
  "_name_or_path": "NickyNicky/LocutusqueXFelladrin-TinyMistral248M-Instruct_oasst2_chatML_V4",
  "architectures": [
    "MixtralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 8,
  "initializer_range": 0.02,
  "intermediate_size": 8,
  "max_position_embeddings": 32768,
  "model_type": "mixtral",
  "num_attention_heads": 4,
  "num_experts_per_tok": 2,
  "num_hidden_layers": 1,
  "num_key_value_heads": 2,
  "num_local_experts": 4,
  "output_hidden_states": true,
  "output_router_logits": false,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_theta": 10000.0,
  "router_aux_loss_coef": 0.001,
  "sliding_window": 3,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.0.dev0",
  "use_cache": false,
  "vocab_size": 32005
}

In [3]:
from transformers import AutoModel

# Instantiate the model from the toy config only; no checkpoint name is
# passed here. The repr printed by the later `tinymixtral` cell shows that
# AutoModel resolved this config to a `MixtralModel`.
tinymixtral = AutoModel.from_config(mixtral_config)

In [4]:
from transformers import AutoTokenizer

# Short probe sentence; it tokenizes to 6 ids including the leading <s>
# (see the outputs of the following cells).
src_sent = "hi how are you doing"

# The tokenizer comes from the Mistral-7B-v0.1 repo, not from the repo the
# config was loaded from.
# NOTE(review): the config reports vocab_size 32005; the ids produced for this
# sentence are well below that, but confirm the two vocabularies line up
# before feeding arbitrary text.
# NOTE(review): `mistal_tokenizer` is a misspelling of "mistral"; the name is
# kept because other cells reference it.
mistal_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

In [5]:
# Tokenize the probe sentence into PyTorch tensors. Calling the tokenizer
# directly is the documented entry point; `encode_plus` is the deprecated
# spelling of the same operation and yields the same `input_ids` /
# `attention_mask` pair.
tokenized_src_dict = mistal_tokenizer(src_sent, return_tensors="pt")
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [6]:
# NOTE(review): redundant cell — it re-displays the same object the previous
# cell already showed (the two outputs are identical).
tokenized_src_dict

{'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [7]:
# Pull out just the token ids; the output below shows a (1, 6) tensor
# (one sequence of six ids).
src_tokenized = tokenized_src_dict["input_ids"]
src_tokenized

tensor([[    1, 12014,   910,   460,   368,  2548]])

In [8]:
# Round-trip check: decode the single sequence in the batch back to text.
# Index row 0 explicitly rather than star-unpacking the batch axis: with more
# than one row, `*src_tokenized` would pass the extra rows as positional
# arguments of `decode` instead of token ids.
mistal_tokenizer.decode(src_tokenized[0])

'<s> hi how are you doing'

In [9]:
# Display the shape of the id tensor; its second entry is the sequence length
# that the mask-building cell reads via `src_tokenized.shape[1]`.
src_tokenized.shape

torch.Size([1, 6])

In [10]:
from pprint import pprint 

# Pretty-print the tokenizer output; the mapping still holds the tokenizer's
# original 2-D attention mask of ones at this point.
pprint(tokenized_src_dict)

{'attention_mask': tensor([[1, 1, 1, 1, 1, 1]]),
 'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]])}


In [11]:
# NOTE(review): leftover introspection — this only echoes the function object
# and has no effect on later cells; consider removing it from the notebook.
torch.ones

<function torch._VariableFunctionsClass.ones>

In [12]:
# Build a causal sliding-window ("band") mask of shape (seq, seq):
# entry [i, j] is 1 when query position i may attend to key position j,
# i.e. when  i - sliding_window_len + 1 <= j <= i,  and 0 otherwise.
seq_length = src_tokenized.shape[1]

# Take the window from the model config instead of repeating the literal 3,
# so the hand-built mask cannot drift away from the model's own setting.
sliding_window_len = mixtral_config.sliding_window

# Lower triangle (diagonal included) gives the causal part: j <= i.
causal_mask = torch.tril(torch.ones(seq_length, seq_length))

# Keeping only the diagonals at or above -(window - 1) drops keys that are
# further back than the window: j >= i - window + 1. This gives the same
# result as zeroing the leading columns of each row in a Python loop, without
# the per-row chained indexing.
sliding_window_mask = torch.triu(causal_mask, diagonal=-(sliding_window_len - 1))

sliding_window_mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [0., 1., 1., 1., 0., 0.],
        [0., 0., 1., 1., 1., 0.],
        [0., 0., 0., 1., 1., 1.]])

In [13]:
# Add two leading singleton axes so the mask is 4-D: (1, 1, seq, seq).
# NOTE(review): presumably these are the batch and head axes the model
# expects for an explicit 4-D attention mask — confirm against the model code.
#
# Reshaping from the trailing two dims (rather than calling `unsqueeze(0)`
# twice on the reassigned name) keeps this cell idempotent: re-running it on
# an already 4-D mask leaves the shape at (1, 1, seq, seq) instead of adding
# two more axes each time.
sliding_window_mask = sliding_window_mask.reshape(1, 1, *sliding_window_mask.shape[-2:])
sliding_window_mask.shape, sliding_window_mask

(torch.Size([1, 1, 6, 6]),
 tensor([[[[1., 0., 0., 0., 0., 0.],
           [1., 1., 0., 0., 0., 0.],
           [1., 1., 1., 0., 0., 0.],
           [0., 1., 1., 1., 0., 0.],
           [0., 0., 1., 1., 1., 0.],
           [0., 0., 0., 1., 1., 1.]]]]))

In [14]:
# Replace the tokenizer's 2-D mask of ones with the hand-built 4-D
# sliding-window mask, so the forward pass below uses it.
# NOTE(review): this overwrites the entry in place — the original tokenizer
# mask is no longer available afterwards, and earlier displays of
# `tokenized_src_dict` are now stale.
tokenized_src_dict["attention_mask"] = sliding_window_mask

pprint(tokenized_src_dict)


{'attention_mask': tensor([[[[1., 0., 0., 0., 0., 0.],
          [1., 1., 0., 0., 0., 0.],
          [1., 1., 1., 0., 0., 0.],
          [0., 1., 1., 1., 0., 0.],
          [0., 0., 1., 1., 1., 0.],
          [0., 0., 0., 1., 1., 1.]]]]),
 'input_ids': tensor([[    1, 12014,   910,   460,   368,  2548]])}


In [15]:
# Show the module tree of the toy model (one decoder layer, four experts) to
# confirm that the shrunken config was applied.
tinymixtral

MixtralModel(
  (embed_tokens): Embedding(32005, 8)
  (layers): ModuleList(
    (0): MixtralDecoderLayer(
      (self_attn): MixtralSdpaAttention(
        (q_proj): Linear(in_features=8, out_features=8, bias=False)
        (k_proj): Linear(in_features=8, out_features=4, bias=False)
        (v_proj): Linear(in_features=8, out_features=4, bias=False)
        (o_proj): Linear(in_features=8, out_features=8, bias=False)
        (rotary_emb): MixtralRotaryEmbedding()
      )
      (block_sparse_moe): MixtralSparseMoeBlock(
        (gate): Linear(in_features=8, out_features=4, bias=False)
        (experts): ModuleList(
          (0-3): 4 x MixtralBlockSparseTop2MLP(
            (w1): Linear(in_features=8, out_features=8, bias=False)
            (w2): Linear(in_features=8, out_features=8, bias=False)
            (w3): Linear(in_features=8, out_features=8, bias=False)
            (act_fn): SiLU()
          )
        )
      )
      (input_layernorm): MixtralRMSNorm()
      (post_attention_layer

In [16]:
# Run one forward pass, unpacking `input_ids` and the custom 4-D
# `attention_mask` as keyword arguments.
# The debug text in the output ("LLAMA DECODER FWD START", "Attention mask =",
# ...) is not produced by any code in this notebook.
# NOTE(review): presumably it comes from print statements added to a locally
# modified modeling file — confirm. In that printout the mask holds 0 where
# the hand-built mask has 1 and about -3.4e38 where it has 0.
output = tinymixtral(**tokenized_src_dict)

###########################################################################
LLAMA DECODER FWD START

Attention mask =  tensor([[[[ 0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38, -3.4028e+38,
           -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38,
           -3.4028e+38],
          [-3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38,
           -3.4028e+38],
          [-3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           -3.4028e+38],
          [-3.4028e+38, -3.4028e+38, -3.4028e+38,  0.0000e+00,  0.0000e+00,
            0.0000e+00]]]])

Input (hidden states) =  tensor([[[-0.0205,  0.0033, -0.0227,  0.0090, -0.0157, -0.0196,  0.0259,
           0.0032],
         [ 0.0156,  0.0015,  0.0287, -0.0089,  0.0089,  0.0155, -0.0087,
          -0.0101],
         [-0.0133, -0.0164,  0.0127, -0.0108,  0.0175, -0.00